prithivMLmods commited on
Commit
dd72372
·
verified ·
1 Parent(s): 3ff2b9e

Delete ultrashape

Browse files
Files changed (44) hide show
  1. ultrashape/__init__.py +0 -17
  2. ultrashape/data/objaverse_dit.py +0 -331
  3. ultrashape/data/objaverse_vae.py +0 -262
  4. ultrashape/data/utils.py +0 -193
  5. ultrashape/models/__init__.py +0 -27
  6. ultrashape/models/autoencoders/__init__.py +0 -21
  7. ultrashape/models/autoencoders/attention_blocks.py +0 -711
  8. ultrashape/models/autoencoders/attention_processors.py +0 -103
  9. ultrashape/models/autoencoders/model.py +0 -377
  10. ultrashape/models/autoencoders/surface_extractors.py +0 -266
  11. ultrashape/models/autoencoders/vae_trainer.py +0 -229
  12. ultrashape/models/autoencoders/volume_decoders.py +0 -440
  13. ultrashape/models/conditioner_mask.py +0 -337
  14. ultrashape/models/denoisers/__init__.py +0 -22
  15. ultrashape/models/denoisers/dit_mask.py +0 -725
  16. ultrashape/models/denoisers/moe_layers.py +0 -177
  17. ultrashape/models/diffusion/flow_matching_dit_trainer.py +0 -313
  18. ultrashape/models/diffusion/transport/__init__.py +0 -97
  19. ultrashape/models/diffusion/transport/integrators.py +0 -142
  20. ultrashape/models/diffusion/transport/path.py +0 -220
  21. ultrashape/models/diffusion/transport/transport.py +0 -534
  22. ultrashape/models/diffusion/transport/utils.py +0 -54
  23. ultrashape/pipelines.py +0 -797
  24. ultrashape/postprocessors.py +0 -209
  25. ultrashape/preprocessors.py +0 -167
  26. ultrashape/rembg.py +0 -32
  27. ultrashape/schedulers.py +0 -480
  28. ultrashape/surface_loaders.py +0 -233
  29. ultrashape/utils/__init__.py +0 -6
  30. ultrashape/utils/ema.py +0 -76
  31. ultrashape/utils/misc.py +0 -200
  32. ultrashape/utils/trainings/__init__.py +0 -1
  33. ultrashape/utils/trainings/callback.py +0 -213
  34. ultrashape/utils/trainings/lr_scheduler.py +0 -53
  35. ultrashape/utils/trainings/mesh.py +0 -128
  36. ultrashape/utils/trainings/mesh_log_callback.py +0 -342
  37. ultrashape/utils/trainings/peft.py +0 -78
  38. ultrashape/utils/typing.py +0 -41
  39. ultrashape/utils/utils.py +0 -128
  40. ultrashape/utils/visualizers/__init__.py +0 -1
  41. ultrashape/utils/visualizers/color_util.py +0 -57
  42. ultrashape/utils/visualizers/html_util.py +0 -64
  43. ultrashape/utils/visualizers/pythreejs_viewer.py +0 -549
  44. ultrashape/utils/voxelize.py +0 -74
ultrashape/__init__.py DELETED
@@ -1,17 +0,0 @@
1
- # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
2
- # except for the third-party components listed below.
3
- # Hunyuan 3D does not impose any additional limitations beyond what is outlined
4
- # in the repsective licenses of these third-party components.
5
- # Users must comply with all terms and conditions of original licenses of these third-party
6
- # components and must ensure that the usage of the third party components adheres to
7
- # all relevant laws and regulations.
8
-
9
- # For avoidance of doubts, Hunyuan 3D means the large language models and
10
- # their software and algorithms, including trained model weights, parameters (including
11
- # optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
12
- # fine-tuning enabling code and other elements of the foregoing made publicly available
13
- # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
14
-
15
- from .pipelines import UltraShapePipeline
16
- from .postprocessors import FaceReducer, FloaterRemover, DegenerateFaceRemover, MeshSimplifier
17
- from .preprocessors import ImageProcessorV2, IMAGE_PROCESSORS, DEFAULT_IMAGEPROCESSOR
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ultrashape/data/objaverse_dit.py DELETED
@@ -1,331 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
-
3
- # ==============================================================================
4
- # Original work Copyright (c) 2025 Tencent.
5
- # Modified work Copyright (c) 2025 UltraShape Team.
6
- #
7
- # Modified by UltraShape on 2025.12.25
8
- # ==============================================================================
9
-
10
- # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
11
- # except for the third-party components listed below.
12
- # Hunyuan 3D does not impose any additional limitations beyond what is outlined
13
- # in the repsective licenses of these third-party components.
14
- # Users must comply with all terms and conditions of original licenses of these third-party
15
- # components and must ensure that the usage of the third party components adheres to
16
- # all relevant laws and regulations.
17
-
18
- # For avoidance of doubts, Hunyuan 3D means the large language models and
19
- # their software and algorithms, including trained model weights, parameters (including
20
- # optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
21
- # fine-tuning enabling code and other elements of the foregoing made publicly available
22
- # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
23
-
24
- import math
25
- import os
26
- import json
27
- from dataclasses import dataclass, field
28
-
29
- import random
30
- import imageio
31
- import numpy as np
32
- import pytorch_lightning as pl
33
- import torch
34
- import torch.nn.functional as F
35
- from torch.utils.data import DataLoader, Dataset
36
- from PIL import Image
37
- import pickle
38
- from ultrashape.utils.typing import *
39
- import pandas as pd
40
- import cv2
41
- import torchvision.transforms as transforms
42
- from pytorch_lightning.utilities import rank_zero_info
43
-
44
- def padding(image, mask, center=True, padding_ratio_range=[1.15, 1.15]):
45
- """
46
- Pad the input image and mask to a square shape with padding ratio.
47
-
48
- Args:
49
- image (np.ndarray): Input image array of shape (H, W, C).
50
- mask (np.ndarray): Corresponding mask array of shape (H, W).
51
- center (bool): Whether to center the original image in the padded output.
52
- padding_ratio_range (list): Range [min, max] to randomly select padding ratio.
53
-
54
- Returns:
55
- newimg (np.ndarray): Padded image of shape (resize_side, resize_side, 3).
56
- newmask (np.ndarray): Padded mask of shape (resize_side, resize_side).
57
- """
58
- h, w = image.shape[:2]
59
- max_side = max(h, w)
60
-
61
- # Select padding ratio either fixed or randomly within the given range
62
- if padding_ratio_range[0] == padding_ratio_range[1]:
63
- padding_ratio = padding_ratio_range[0]
64
- else:
65
- padding_ratio = random.uniform(padding_ratio_range[0], padding_ratio_range[1])
66
- resize_side = int(max_side * padding_ratio)
67
-
68
- pad_h = resize_side - h
69
- pad_w = resize_side - w
70
- if center:
71
- start_h = pad_h // 2
72
- else:
73
- start_h = pad_h - resize_side // 20
74
-
75
- start_w = pad_w // 2
76
-
77
- # Create new white image and black mask with padded size
78
- newimg = np.ones((resize_side, resize_side, 3), dtype=np.uint8) * 255
79
- newmask = np.zeros((resize_side, resize_side), dtype=np.uint8)
80
-
81
- # Place original image and mask into the padded canvas
82
- newimg[start_h:start_h + h, start_w:start_w + w] = image
83
- newmask[start_h:start_h + h, start_w:start_w + w] = mask
84
-
85
- return newimg, newmask
86
-
87
-
88
- class ObjaverseDataset(Dataset):
89
- def __init__(
90
- self,
91
- data_json,
92
- sample_root,
93
- image_path,
94
- image_transform = None,
95
- pc_size: int = 2048,
96
- pc_sharpedge_size: int = 2048,
97
- sharpedge_label: bool = False,
98
- return_normal: bool = False,
99
- padding = True,
100
- padding_ratio_range=[1.15, 1.15],
101
- ):
102
- super().__init__()
103
-
104
- self.uids = json.load(open(data_json))
105
- self.sample_root = sample_root
106
- self.image_paths = json.load(open(image_path))
107
- self.image_transform = image_transform
108
-
109
- self.pc_size = pc_size
110
- self.pc_sharpedge_size = pc_sharpedge_size
111
- self.sharpedge_label = sharpedge_label
112
- self.return_normal = return_normal
113
-
114
- self.padding = padding
115
- self.padding_ratio_range = padding_ratio_range
116
-
117
- print(f"Loaded {len(self.uids)} uids from {data_json}.")
118
-
119
- rank_zero_info(f'*' * 50)
120
- rank_zero_info(f'Dataset Infos:')
121
- rank_zero_info(f'# of 3D file: {len(self.uids)}')
122
- rank_zero_info(f'# of Surface Points: {self.pc_size}')
123
- rank_zero_info(f'# of Sharpedge Surface Points: {self.pc_sharpedge_size}')
124
- rank_zero_info(f'Using sharp edge label: {self.sharpedge_label}')
125
- rank_zero_info(f'*' * 50)
126
-
127
- def __len__(self):
128
- return len(self.uids)
129
-
130
- def _load_shape(self, index: int) -> Dict[str, Any]:
131
-
132
- data = np.load(f'{self.sample_root}/{self.uids[index]}.npz')
133
-
134
- surface_og = (np.asarray(data['clean_surface_points'])-0.5) * 2
135
- normal = np.asarray(data['clean_surface_normals'])
136
- surface_og_n = np.concatenate([surface_og, normal], axis=1)
137
- rng = np.random.default_rng()
138
-
139
- # hard code: first 300k are uniform, last 300k are sharp
140
- assert surface_og_n.shape[0] == 600000, f"assume that suface points = 30w uniform + 30w curvature, but {len(surface_og_n)=}"
141
- coarse_surface = surface_og_n[:300000]
142
- sharp_surface = surface_og_n[300000:]
143
-
144
- surface_normal = []
145
- rng = np.random.default_rng()
146
- if self.pc_size > 0:
147
- ind = rng.choice(coarse_surface.shape[0], self.pc_size // 2, replace=False)
148
- coarse_surface = coarse_surface[ind]
149
- if self.sharpedge_label:
150
- sharpedge_label = np.zeros((self.pc_size // 2, 1))
151
- coarse_surface = np.concatenate((coarse_surface, sharpedge_label), axis=1)
152
- surface_normal.append(coarse_surface)
153
-
154
- ind_sharpedge = rng.choice(sharp_surface.shape[0], self.pc_size // 2, replace=False)
155
- sharp_surface = sharp_surface[ind_sharpedge]
156
- if self.sharpedge_label:
157
- sharpedge_label = np.ones((self.pc_size // 2, 1))
158
- sharp_surface = np.concatenate((sharp_surface, sharpedge_label), axis=1)
159
- surface_normal.append(sharp_surface)
160
-
161
- surface_normal = np.concatenate(surface_normal, axis=0)
162
- surface_normal = torch.FloatTensor(surface_normal)
163
- surface = surface_normal[:, 0:3]
164
- normal = surface_normal[:, 3:6]
165
- assert surface.shape[0] == self.pc_size + self.pc_sharpedge_size
166
-
167
- geo_points = 0.0
168
- normal = torch.nn.functional.normalize(normal, p=2, dim=1)
169
- if self.return_normal:
170
- surface = torch.cat([surface, normal], dim=-1)
171
- if self.sharpedge_label:
172
- surface = torch.cat([surface, surface_normal[:, -1:]], dim=-1)
173
-
174
- ret = {
175
- "uid": self.uids[index],
176
- "surface": surface,
177
- "geo_points": geo_points
178
- }
179
- return ret
180
-
181
-
182
- def _load_image(self, index: int) -> Dict[str, Any]:
183
- ret = {}
184
- sel_idx = random.randint(0, 15)
185
- ret["sel_image_idx"] = sel_idx
186
- obj_name = self.uids[index]
187
- img_path = f'{self.image_paths[obj_name]}/{os.path.basename(self.image_paths[obj_name])}/rgba/' + f"{sel_idx:03d}.png"
188
-
189
- images, masks = [], []
190
- image = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
191
- assert image.shape[2] == 4
192
- alpha = image[:, :, 3:4].astype(np.float32) / 255
193
- forground = image[:, :, :3]
194
- background = np.ones_like(forground) * 255
195
- img_new = forground * alpha + background * (1 - alpha)
196
- image = img_new.astype(np.uint8)
197
- image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
198
- mask = (alpha[:, :, 0] * 255).astype(np.uint8)
199
-
200
- if self.padding:
201
- h, w = image.shape[:2]
202
- binary = mask > 0.3
203
- non_zero_coords = np.argwhere(binary)
204
- x_min, y_min = non_zero_coords.min(axis=0)
205
- x_max, y_max = non_zero_coords.max(axis=0)
206
- image, mask = padding(
207
- image[max(x_min - 5, 0):min(x_max + 5, h), max(y_min - 5, 0):min(y_max + 5, w)],
208
- mask[max(x_min - 5, 0):min(x_max + 5, h), max(y_min - 5, 0):min(y_max + 5, w)],
209
- center=True, padding_ratio_range=self.padding_ratio_range)
210
-
211
- if self.image_transform:
212
- image = self.image_transform(image)
213
- mask = np.stack((mask, mask, mask), axis=-1)
214
- mask = self.image_transform(mask)
215
-
216
- images.append(image)
217
- masks.append(mask)
218
- ret["image"] = torch.cat(images, dim=0)
219
- ret["mask"] = torch.cat(masks, dim=0)[:1, ...]
220
-
221
- return ret
222
-
223
- def get_data(self, index):
224
- ret = self._load_shape(index)
225
- ret.update(self._load_image(index))
226
- return ret
227
-
228
- def __getitem__(self, index):
229
- try:
230
- return self.get_data(index)
231
- except Exception as e:
232
- print(f"Error in {self.uids[index]}: {e}")
233
- return self.__getitem__(np.random.randint(len(self)))
234
-
235
- def collate(self, batch):
236
- batch = torch.utils.data.default_collate(batch)
237
- return batch
238
-
239
-
240
- class ObjaverseDataModule(pl.LightningDataModule):
241
- def __init__(
242
- self,
243
- batch_size: int = 1,
244
- num_workers: int = 4,
245
- val_num_workers: int = 2,
246
- training_data_list: str = None,
247
- sample_pcd_dir: str = None,
248
- image_data_json: str = None,
249
- image_size: int = 224,
250
- mean: Union[List[float], Tuple[float]] = (0.485, 0.456, 0.406),
251
- std: Union[List[float], Tuple[float]] = (0.229, 0.224, 0.225),
252
- pc_size: int = 2048,
253
- pc_sharpedge_size: int = 2048,
254
- sharpedge_label: bool = False,
255
- return_normal: bool = False,
256
- padding = True,
257
- padding_ratio_range=[1.15, 1.15]
258
- ):
259
-
260
- super().__init__()
261
- self.batch_size = batch_size
262
- self.num_workers = num_workers
263
- self.val_num_workers = val_num_workers
264
-
265
- self.training_data_list = training_data_list
266
- self.sample_pcd_dir = sample_pcd_dir
267
- self.image_data_json = image_data_json
268
-
269
- self.image_size = image_size
270
- self.mean = mean
271
- self.std = std
272
- self.train_image_transform = transforms.Compose([
273
- transforms.ToTensor(),
274
- transforms.Resize(self.image_size),
275
- transforms.Normalize(mean=self.mean, std=self.std)])
276
- self.val_image_transform = transforms.Compose([
277
- transforms.ToTensor(),
278
- transforms.Resize(self.image_size),
279
- transforms.Normalize(mean=self.mean, std=self.std)])
280
-
281
- self.pc_size = pc_size
282
- self.pc_sharpedge_size = pc_sharpedge_size
283
- self.sharpedge_label = sharpedge_label
284
- self.return_normal = return_normal
285
-
286
- self.padding = padding
287
- self.padding_ratio_range = padding_ratio_range
288
-
289
- def train_dataloader(self):
290
- asl_params = {
291
- "data_json": f'{self.training_data_list}/train.json',
292
- "sample_root": self.sample_pcd_dir,
293
- "image_path": self.image_data_json,
294
- "image_transform": self.train_image_transform,
295
- "pc_size": self.pc_size,
296
- "pc_sharpedge_size": self.pc_sharpedge_size,
297
- "sharpedge_label": self.sharpedge_label,
298
- "return_normal": self.return_normal,
299
- "padding": self.padding,
300
- "padding_ratio_range": self.padding_ratio_range,
301
- }
302
- dataset = ObjaverseDataset(**asl_params)
303
- return torch.utils.data.DataLoader(
304
- dataset,
305
- batch_size=self.batch_size,
306
- num_workers=self.num_workers,
307
- pin_memory=True,
308
- drop_last=True,
309
- )
310
-
311
- def val_dataloader(self):
312
- asl_params = {
313
- "data_json": f'{self.training_data_list}/val.json',
314
- "sample_root": self.sample_pcd_dir,
315
- "image_path": self.image_data_json,
316
- "image_transform": self.val_image_transform,
317
- "pc_size": self.pc_size,
318
- "pc_sharpedge_size": self.pc_sharpedge_size,
319
- "sharpedge_label": self.sharpedge_label,
320
- "return_normal": self.return_normal,
321
- "padding": self.padding,
322
- "padding_ratio_range": self.padding_ratio_range,
323
- }
324
- dataset = ObjaverseDataset(**asl_params)
325
- return torch.utils.data.DataLoader(
326
- dataset,
327
- batch_size=self.batch_size,
328
- num_workers=self.val_num_workers,
329
- pin_memory=True,
330
- drop_last=True,
331
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ultrashape/data/objaverse_vae.py DELETED
@@ -1,262 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
-
3
- # ==============================================================================
4
- # Original work Copyright (c) 2025 Tencent.
5
- # Modified work Copyright (c) 2025 UltraShape Team.
6
- #
7
- # Modified by UltraShape on 2025.12.25
8
- # ==============================================================================
9
-
10
- # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
11
- # except for the third-party components listed below.
12
- # Hunyuan 3D does not impose any additional limitations beyond what is outlined
13
- # in the repsective licenses of these third-party components.
14
- # Users must comply with all terms and conditions of original licenses of these third-party
15
- # components and must ensure that the usage of the third party components adheres to
16
- # all relevant laws and regulations.
17
-
18
- # For avoidance of doubts, Hunyuan 3D means the large language models and
19
- # their software and algorithms, including trained model weights, parameters (including
20
- # optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
21
- # fine-tuning enabling code and other elements of the foregoing made publicly available
22
- # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
23
-
24
-
25
- import os
26
- import cv2
27
- import json
28
- import math
29
- import random
30
- import imageio
31
- import pickle
32
- import numpy as np
33
- from PIL import Image
34
- import pandas as pd
35
- from dataclasses import dataclass, field
36
-
37
- import torch
38
- import torch.nn.functional as F
39
- import pytorch_lightning as pl
40
- from torch.utils.data import DataLoader, Dataset
41
- import torchvision.transforms as transforms
42
- from pytorch_lightning.utilities import rank_zero_info
43
- from ultrashape.utils.typing import *
44
-
45
- class ObjaverseDataset(Dataset):
46
- def __init__(
47
- self,
48
- data_json,
49
- sample_root,
50
- pc_size: int = 2048,
51
- pc_sharpedge_size: int = 2048,
52
- sup_near_uni_size: int = 4096,
53
- sup_near_sharp_size: int = 4096,
54
- sup_space_size: int = 4096,
55
- tsdf_threshold: float = 0.05,
56
- sharpedge_label: bool = False,
57
- return_normal: bool = False,
58
- ):
59
- super().__init__()
60
-
61
- self.uids = json.load(open(data_json))
62
- self.sample_root = sample_root
63
-
64
- self.pc_size = pc_size
65
- self.pc_sharpedge_size = pc_sharpedge_size
66
- self.sharpedge_label = sharpedge_label
67
- self.return_normal = return_normal
68
-
69
- self.sup_near_uni_size = sup_near_uni_size
70
- self.sup_near_sharp_size = sup_near_sharp_size
71
- self.sup_space_size = sup_space_size
72
- self.tsdf_threshold = tsdf_threshold
73
-
74
- print(f"Loaded {len(self.uids)} uids from {data_json}.")
75
-
76
- rank_zero_info(f'*' * 50)
77
- rank_zero_info(f'Dataset Infos:')
78
- rank_zero_info(f'# of 3D file: {len(self.uids)}')
79
- rank_zero_info(f'# of Surface Points: {self.pc_size}')
80
- rank_zero_info(f'# of Sharpedge Surface Points: {self.pc_sharpedge_size}')
81
- rank_zero_info(f'# of Uniform Near-Surface Sup-Points: {self.sup_near_uni_size}')
82
- rank_zero_info(f'# of Sharpedge Near-Surface Sup-Points: {self.sup_near_sharp_size}')
83
- rank_zero_info(f'# of Random Space Sup-Points: {self.sup_space_size}')
84
- rank_zero_info(f'Using sharp edge label: {self.sharpedge_label}')
85
- rank_zero_info(f'*' * 50)
86
-
87
- def __len__(self):
88
- return len(self.uids)
89
-
90
- def _load_shape(self, index: int) -> Dict[str, Any]:
91
- rng = np.random.default_rng()
92
-
93
- data = np.load(f'{self.sample_root}/{self.uids[index]}.npz')
94
-
95
- ##################### sup pcd&sdf ######################
96
- uniform_near_points = (np.asarray(data['uniform_near_points'])-0.5) * 2
97
- curvature_near_points = (np.asarray(data['curvature_near_points'])-0.5) * 2
98
- space_points = (np.asarray(data['space_points'])-0.5) * 2
99
- uniform_near_sdf = np.asarray(data['uniform_near_sdf']) * 2
100
- curvature_near_sdf = np.asarray(data['curvature_near_sdf']) * 2
101
- space_sdf = np.asarray(data['space_sdf']) * 2
102
-
103
- uni_noisy_idx = rng.choice(uniform_near_points.shape[0], self.sup_near_uni_size, replace=False)
104
- cur_noisy_idx = rng.choice(curvature_near_points.shape[0], self.sup_near_sharp_size, replace=False)
105
- space_idx = rng.choice(space_points.shape[0], self.sup_space_size, replace=False)
106
-
107
- uniform_near_points = uniform_near_points[uni_noisy_idx]
108
- curvature_near_points = curvature_near_points[cur_noisy_idx]
109
- space_points = space_points[space_idx]
110
- uniform_near_sdf = uniform_near_sdf[uni_noisy_idx]
111
- curvature_near_sdf = curvature_near_sdf[cur_noisy_idx]
112
- space_sdf = space_sdf[space_idx]
113
-
114
- uniform_near_sdf, curvature_near_sdf, space_sdf = map(self._clip_to_tsdf, (uniform_near_sdf, curvature_near_sdf, space_sdf))
115
-
116
- surface_og = (np.asarray(data['clean_surface_points'])-0.5) * 2
117
- normal = np.asarray(data['clean_surface_normals'])
118
- surface_og_n = np.concatenate([surface_og, normal], axis=1)
119
- rng = np.random.default_rng()
120
-
121
- # hard code: first 300k are uniform, last 300k are sharp
122
- assert surface_og_n.shape[0] == 600000, f"assume that suface points = 30w uniform + 30w curvature, but {len(surface_og_n)=}"
123
- coarse_surface = surface_og_n[:300000]
124
- sharp_surface = surface_og_n[300000:]
125
-
126
- surface_normal = []
127
-
128
- if self.pc_size > 0:
129
- ind = rng.choice(coarse_surface.shape[0], self.pc_size // 2, replace=False)
130
- coarse_surface = coarse_surface[ind]
131
- if self.sharpedge_label:
132
- sharpedge_label = np.zeros((self.pc_size // 2, 1))
133
- coarse_surface = np.concatenate((coarse_surface, sharpedge_label), axis=1)
134
- surface_normal.append(coarse_surface)
135
-
136
- ind_sharpedge = rng.choice(sharp_surface.shape[0], self.pc_size // 2, replace=False)
137
- sharp_surface = sharp_surface[ind_sharpedge]
138
- if self.sharpedge_label:
139
- sharpedge_label = np.ones((self.pc_size // 2, 1))
140
- sharp_surface = np.concatenate((sharp_surface, sharpedge_label), axis=1)
141
- surface_normal.append(sharp_surface)
142
-
143
- surface_normal = np.concatenate(surface_normal, axis=0)
144
- surface_normal = torch.FloatTensor(surface_normal)
145
- surface = surface_normal[:, 0:3]
146
- normal = surface_normal[:, 3:6]
147
- assert surface.shape[0] == self.pc_size + self.pc_sharpedge_size
148
-
149
- geo_points = 0.0
150
- normal = torch.nn.functional.normalize(normal, p=2, dim=1)
151
- if self.return_normal:
152
- surface = torch.cat([surface, normal], dim=-1)
153
- if self.sharpedge_label:
154
- surface = torch.cat([surface, surface_normal[:, -1:]], dim=-1)
155
-
156
- ret = {
157
- "uid": self.uids[index],
158
- "surface": surface,
159
- "sup_near_uniform": np.concatenate([uniform_near_points, uniform_near_sdf[...,None]], axis=1),
160
- "sup_near_sharp": np.concatenate([curvature_near_points, curvature_near_sdf[...,None]], axis=1),
161
- "sup_space": np.concatenate([space_points, space_sdf[...,None]], axis=1),
162
- "geo_points": geo_points
163
- }
164
- return ret
165
-
166
- def _clip_to_tsdf(self, sdf: np.array):
167
- nan_mask = np.isnan(sdf)
168
- if np.any(nan_mask):
169
- sdf=np.nan_to_num(sdf, nan=1.0, posinf=1.0, neginf=-1.0)
170
- return sdf.flatten().astype(np.float32).clip(-self.tsdf_threshold, self.tsdf_threshold) / self.tsdf_threshold
171
-
172
- def get_data(self, index):
173
- ret = self._load_shape(index)
174
- return ret
175
-
176
- def __getitem__(self, index):
177
- return self.get_data(index)
178
-
179
- def collate(self, batch):
180
- batch = torch.utils.data.default_collate(batch)
181
- return batch
182
-
183
-
184
- class ObjaverseDataModule(pl.LightningDataModule):
185
- def __init__(
186
- self,
187
- batch_size: int = 1,
188
- num_workers: int = 4,
189
- val_num_workers: int = 2,
190
- training_data_list: str = None,
191
- sample_pcd_dir: str = None,
192
- pc_size: int = 2048,
193
- pc_sharpedge_size: int = 2048,
194
- sup_near_uni_size: int = 4096,
195
- sup_near_sharp_size: int = 4096,
196
- sup_space_size: int = 4096,
197
- tsdf_threshold: float = 0.05,
198
- sharpedge_label: bool = False,
199
- return_normal: bool = False,
200
- ):
201
-
202
- super().__init__()
203
- self.batch_size = batch_size
204
- self.num_workers = num_workers
205
- self.val_num_workers = val_num_workers
206
-
207
- self.training_data_list = training_data_list
208
- self.sample_pcd_dir = sample_pcd_dir
209
-
210
- self.pc_size = pc_size
211
- self.pc_sharpedge_size = pc_sharpedge_size
212
- self.sharpedge_label = sharpedge_label
213
- self.return_normal = return_normal
214
-
215
- self.sup_near_uni_size = sup_near_uni_size
216
- self.sup_near_sharp_size = sup_near_sharp_size
217
- self.sup_space_size = sup_space_size
218
- self.tsdf_threshold = tsdf_threshold
219
-
220
- def train_dataloader(self):
221
- asl_params = {
222
- "data_json": f'{self.training_data_list}/train.json',
223
- "sample_root": self.sample_pcd_dir,
224
- "pc_size": self.pc_size,
225
- "pc_sharpedge_size": self.pc_sharpedge_size,
226
- "sup_near_uni_size": self.sup_near_uni_size,
227
- "sup_near_sharp_size": self.sup_near_sharp_size,
228
- "sup_space_size": self.sup_space_size,
229
- "tsdf_threshold": self.tsdf_threshold,
230
- "sharpedge_label": self.sharpedge_label,
231
- "return_normal": self.return_normal,
232
- }
233
- dataset = ObjaverseDataset(**asl_params)
234
- return torch.utils.data.DataLoader(
235
- dataset,
236
- batch_size=self.batch_size,
237
- num_workers=self.num_workers,
238
- pin_memory=True,
239
- drop_last=True,
240
- )
241
-
242
- def val_dataloader(self):
243
- asl_params = {
244
- "data_json": f'{self.training_data_list}/val.json',
245
- "sample_root": self.sample_pcd_dir,
246
- "pc_size": self.pc_size,
247
- "pc_sharpedge_size": self.pc_sharpedge_size,
248
- "sup_near_uni_size": self.sup_near_uni_size,
249
- "sup_near_sharp_size": self.sup_near_sharp_size,
250
- "sup_space_size": self.sup_space_size,
251
- "tsdf_threshold": self.tsdf_threshold,
252
- "sharpedge_label": self.sharpedge_label,
253
- "return_normal": self.return_normal,
254
- }
255
- dataset = ObjaverseDataset(**asl_params)
256
- return torch.utils.data.DataLoader(
257
- dataset,
258
- batch_size=self.batch_size,
259
- num_workers=self.val_num_workers,
260
- pin_memory=True,
261
- drop_last=True,
262
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ultrashape/data/utils.py DELETED
@@ -1,193 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
-
3
- # ==============================================================================
4
- # Original work Copyright (c) 2025 Tencent.
5
- # Modified work Copyright (c) 2025 UltraShape Team.
6
- #
7
- # Modified by UltraShape on 2025.12.25
8
- # ==============================================================================
9
-
10
- # Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
11
- # This file is part of the WebDataset library.
12
- # See the LICENSE file for licensing terms (BSD-style).
13
-
14
-
15
- """Miscellaneous utility functions."""
16
-
17
- import importlib
18
- import itertools as itt
19
- import os
20
- import re
21
- import sys
22
- from typing import Any, Callable, Iterator, Union
23
- import torch
24
- import numpy as np
25
-
26
-
27
- def make_seed(*args):
28
- seed = 0
29
- for arg in args:
30
- seed = (seed * 31 + hash(arg)) & 0x7FFFFFFF
31
- return seed
32
-
33
-
34
- class PipelineStage:
35
- def invoke(self, *args, **kw):
36
- raise NotImplementedError
37
-
38
-
39
- def identity(x: Any) -> Any:
40
- """Return the argument as is."""
41
- return x
42
-
43
-
44
- def safe_eval(s: str, expr: str = "{}"):
45
- """Evaluate the given expression more safely."""
46
- if re.sub("[^A-Za-z0-9_]", "", s) != s:
47
- raise ValueError(f"safe_eval: illegal characters in: '{s}'")
48
- return eval(expr.format(s))
49
-
50
-
51
- def lookup_sym(sym: str, modules: list):
52
- """Look up a symbol in a list of modules."""
53
- for mname in modules:
54
- module = importlib.import_module(mname, package="webdataset")
55
- result = getattr(module, sym, None)
56
- if result is not None:
57
- return result
58
- return None
59
-
60
-
61
- def repeatedly0(
62
- loader: Iterator, nepochs: int = sys.maxsize, nbatches: int = sys.maxsize
63
- ):
64
- """Repeatedly returns batches from a DataLoader."""
65
- for _ in range(nepochs):
66
- yield from itt.islice(loader, nbatches)
67
-
68
-
69
- def guess_batchsize(batch: Union[tuple, list]):
70
- """Guess the batch size by looking at the length of the first element in a tuple."""
71
- return len(batch[0])
72
-
73
-
74
- def repeatedly(
75
- source: Iterator,
76
- nepochs: int = None,
77
- nbatches: int = None,
78
- nsamples: int = None,
79
- batchsize: Callable[..., int] = guess_batchsize,
80
- ):
81
- """Repeatedly yield samples from an iterator."""
82
- epoch = 0
83
- batch = 0
84
- total = 0
85
- while True:
86
- for sample in source:
87
- yield sample
88
- batch += 1
89
- if nbatches is not None and batch >= nbatches:
90
- return
91
- if nsamples is not None:
92
- total += guess_batchsize(sample)
93
- if total >= nsamples:
94
- return
95
- epoch += 1
96
- if nepochs is not None and epoch >= nepochs:
97
- return
98
-
99
-
100
- def pytorch_worker_info(group=None): # sourcery skip: use-contextlib-suppress
101
- """Return node and worker info for PyTorch and some distributed environments."""
102
- rank = 0
103
- world_size = 1
104
- worker = 0
105
- num_workers = 1
106
- if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
107
- rank = int(os.environ["RANK"])
108
- world_size = int(os.environ["WORLD_SIZE"])
109
- else:
110
- try:
111
- import torch.distributed
112
-
113
- if torch.distributed.is_available() and torch.distributed.is_initialized():
114
- group = group or torch.distributed.group.WORLD
115
- rank = torch.distributed.get_rank(group=group)
116
- world_size = torch.distributed.get_world_size(group=group)
117
- except ModuleNotFoundError:
118
- pass
119
- if "WORKER" in os.environ and "NUM_WORKERS" in os.environ:
120
- worker = int(os.environ["WORKER"])
121
- num_workers = int(os.environ["NUM_WORKERS"])
122
- else:
123
- try:
124
- import torch.utils.data
125
-
126
- worker_info = torch.utils.data.get_worker_info()
127
- if worker_info is not None:
128
- worker = worker_info.id
129
- num_workers = worker_info.num_workers
130
- except ModuleNotFoundError:
131
- pass
132
-
133
- return rank, world_size, worker, num_workers
134
-
135
-
136
- def pytorch_worker_seed(group=None):
137
- """Compute a distinct, deterministic RNG seed for each worker and node."""
138
- rank, world_size, worker, num_workers = pytorch_worker_info(group=group)
139
- return rank * 1000 + worker
140
-
141
- def worker_init_fn(_):
142
- worker_info = torch.utils.data.get_worker_info()
143
- worker_id = worker_info.id
144
-
145
- # dataset = worker_info.dataset
146
- # split_size = dataset.num_records // worker_info.num_workers
147
- # # reset num_records to the true number to retain reliable length information
148
- # dataset.sample_ids = dataset.valid_ids[worker_id * split_size:(worker_id + 1) * split_size]
149
- # current_id = np.random.choice(len(np.random.get_state()[1]), 1)
150
- # return np.random.seed(np.random.get_state()[1][current_id] + worker_id)
151
-
152
- return np.random.seed(np.random.get_state()[1][0] + worker_id)
153
-
154
-
155
- def collation_fn(samples, combine_tensors=True, combine_scalars=True):
156
- """
157
-
158
- Args:
159
- samples (list[dict]):
160
- combine_tensors:
161
- combine_scalars:
162
-
163
- Returns:
164
-
165
- """
166
-
167
- result = {}
168
-
169
- keys = samples[0].keys()
170
-
171
- for key in keys:
172
- result[key] = []
173
-
174
- for sample in samples:
175
- for key in keys:
176
- val = sample[key]
177
- result[key].append(val)
178
-
179
- for key in keys:
180
- val_list = result[key]
181
- if isinstance(val_list[0], (int, float)):
182
- if combine_scalars:
183
- result[key] = np.array(result[key])
184
-
185
- elif isinstance(val_list[0], torch.Tensor):
186
- if combine_tensors:
187
- result[key] = torch.stack(val_list)
188
-
189
- elif isinstance(val_list[0], np.ndarray):
190
- if combine_tensors:
191
- result[key] = np.stack(val_list)
192
-
193
- return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ultrashape/models/__init__.py DELETED
@@ -1,27 +0,0 @@
1
- # Open Source Model Licensed under the Apache License Version 2.0
2
- # and Other Licenses of the Third-Party Components therein:
3
- # The below Model in this distribution may have been modified by THL A29 Limited
4
- # ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
5
-
6
- # Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
7
- # The below software and/or models in this distribution may have been
8
- # modified by THL A29 Limited ("Tencent Modifications").
9
- # All Tencent Modifications are Copyright (C) THL A29 Limited.
10
-
11
- # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
12
- # except for the third-party components listed below.
13
- # Hunyuan 3D does not impose any additional limitations beyond what is outlined
14
- # in the repsective licenses of these third-party components.
15
- # Users must comply with all terms and conditions of original licenses of these third-party
16
- # components and must ensure that the usage of the third party components adheres to
17
- # all relevant laws and regulations.
18
-
19
- # For avoidance of doubts, Hunyuan 3D means the large language models and
20
- # their software and algorithms, including trained model weights, parameters (including
21
- # optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
22
- # fine-tuning enabling code and other elements of the foregoing made publicly available
23
- # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
24
-
25
- from .autoencoders import ShapeVAE
26
- from .conditioner_mask import DualImageEncoder, SingleImageEncoder, DinoImageEncoder, CLIPImageEncoder
27
- from .denoisers import RefineDiT
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ultrashape/models/autoencoders/__init__.py DELETED
@@ -1,21 +0,0 @@
1
- # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
2
- # except for the third-party components listed below.
3
- # Hunyuan 3D does not impose any additional limitations beyond what is outlined
4
- # in the repsective licenses of these third-party components.
5
- # Users must comply with all terms and conditions of original licenses of these third-party
6
- # components and must ensure that the usage of the third party components adheres to
7
- # all relevant laws and regulations.
8
-
9
- # For avoidance of doubts, Hunyuan 3D means the large language models and
10
- # their software and algorithms, including trained model weights, parameters (including
11
- # optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
12
- # fine-tuning enabling code and other elements of the foregoing made publicly available
13
- # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
14
-
15
- from .attention_blocks import CrossAttentionDecoder
16
- from .attention_processors import FlashVDMCrossAttentionProcessor, CrossAttentionProcessor, \
17
- FlashVDMTopMCrossAttentionProcessor
18
- from .model import ShapeVAE, VectsetVAE
19
- from .surface_extractors import SurfaceExtractors, MCSurfaceExtractor, DMCSurfaceExtractor, Latent2MeshOutput
20
- from .volume_decoders import HierarchicalVolumeDecoding, FlashVDMVolumeDecoding, VanillaVolumeDecoder
21
- from .vae_trainer import VAETrainer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ultrashape/models/autoencoders/attention_blocks.py DELETED
@@ -1,711 +0,0 @@
1
- # ==============================================================================
2
- # Original work Copyright (c) 2025 Tencent.
3
- # Modified work Copyright (c) 2025 UltraShape Team.
4
- #
5
- # Modified by UltraShape on 2025.12.25
6
- # ==============================================================================
7
-
8
- # Open Source Model Licensed under the Apache License Version 2.0
9
- # and Other Licenses of the Third-Party Components therein:
10
- # The below Model in this distribution may have been modified by THL A29 Limited
11
- # ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
12
-
13
- # Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
14
- # The below software and/or models in this distribution may have been
15
- # modified by THL A29 Limited ("Tencent Modifications").
16
- # All Tencent Modifications are Copyright (C) THL A29 Limited.
17
-
18
- # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
19
- # except for the third-party components listed below.
20
- # Hunyuan 3D does not impose any additional limitations beyond what is outlined
21
- # in the repsective licenses of these third-party components.
22
- # Users must comply with all terms and conditions of original licenses of these third-party
23
- # components and must ensure that the usage of the third party components adheres to
24
- # all relevant laws and regulations.
25
-
26
- # For avoidance of doubts, Hunyuan 3D means the large language models and
27
- # their software and algorithms, including trained model weights, parameters (including
28
- # optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
29
- # fine-tuning enabling code and other elements of the foregoing made publicly available
30
- # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
31
-
32
-
33
- import os
34
- from typing import Optional, Union, List
35
-
36
- import torch
37
- import torch.nn as nn
38
- from einops import rearrange
39
- from torch import Tensor
40
-
41
- from .attention_processors import CrossAttentionProcessor
42
- from ...utils import logger
43
- from ultrashape.utils import voxelize_from_point
44
-
45
- scaled_dot_product_attention = nn.functional.scaled_dot_product_attention
46
-
47
- if os.environ.get('USE_SAGEATTN', '0') == '1':
48
- try:
49
- from sageattention import sageattn
50
- except ImportError:
51
- raise ImportError('Please install the package "sageattention" to use this USE_SAGEATTN.')
52
- scaled_dot_product_attention = sageattn
53
-
54
-
55
- class FourierEmbedder(nn.Module):
56
- """ The sin/cosine positional embedding. """
57
-
58
- def __init__(self,
59
- num_freqs: int = 6,
60
- logspace: bool = True,
61
- input_dim: int = 3,
62
- include_input: bool = True,
63
- include_pi: bool = True) -> None:
64
-
65
- super().__init__()
66
-
67
- if logspace:
68
- frequencies = 2.0 ** torch.arange(
69
- num_freqs,
70
- dtype=torch.float32
71
- )
72
- else:
73
- frequencies = torch.linspace(
74
- 1.0,
75
- 2.0 ** (num_freqs - 1),
76
- num_freqs,
77
- dtype=torch.float32
78
- )
79
-
80
- if include_pi:
81
- frequencies *= torch.pi
82
-
83
- self.register_buffer("frequencies", frequencies, persistent=False)
84
- self.include_input = include_input
85
- self.num_freqs = num_freqs
86
-
87
- self.out_dim = self.get_dims(input_dim)
88
-
89
- def get_dims(self, input_dim):
90
- temp = 1 if self.include_input or self.num_freqs == 0 else 0
91
- out_dim = input_dim * (self.num_freqs * 2 + temp)
92
-
93
- return out_dim
94
-
95
- def forward(self, x: torch.Tensor) -> torch.Tensor:
96
- """ Forward process.
97
-
98
- Args:
99
- x: tensor of shape [..., dim]
100
-
101
- Returns:
102
- embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)]
103
- where temp is 1 if include_input is True and 0 otherwise.
104
- """
105
-
106
- if self.num_freqs > 0:
107
- embed = (x[..., None].contiguous() * self.frequencies).view(*x.shape[:-1], -1)
108
- if self.include_input:
109
- return torch.cat((x, embed.sin(), embed.cos()), dim=-1)
110
- else:
111
- return torch.cat((embed.sin(), embed.cos()), dim=-1)
112
- else:
113
- return x
114
-
115
-
116
- class DropPath(nn.Module):
117
- """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
118
- """
119
-
120
- def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
121
- super(DropPath, self).__init__()
122
- self.drop_prob = drop_prob
123
- self.scale_by_keep = scale_by_keep
124
-
125
- def forward(self, x):
126
- """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
127
-
128
- This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
129
- the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
130
- See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
131
- changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
132
- 'survival rate' as the argument.
133
-
134
- """
135
- if self.drop_prob == 0. or not self.training:
136
- return x
137
- keep_prob = 1 - self.drop_prob
138
- shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
139
- random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
140
- if keep_prob > 0.0 and self.scale_by_keep:
141
- random_tensor.div_(keep_prob)
142
- return x * random_tensor
143
-
144
- def extra_repr(self):
145
- return f'drop_prob={round(self.drop_prob, 3):0.3f}'
146
-
147
-
148
- class MLP(nn.Module):
149
- def __init__(
150
- self, *,
151
- width: int,
152
- expand_ratio: int = 4,
153
- output_width: int = None,
154
- drop_path_rate: float = 0.0
155
- ):
156
- super().__init__()
157
- self.width = width
158
- self.c_fc = nn.Linear(width, width * expand_ratio)
159
- self.c_proj = nn.Linear(width * expand_ratio, output_width if output_width is not None else width)
160
- self.gelu = nn.GELU()
161
- self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
162
-
163
- def forward(self, x):
164
- return self.drop_path(self.c_proj(self.gelu(self.c_fc(x))))
165
-
166
-
167
- class QKVMultiheadCrossAttention(nn.Module):
168
- def __init__(
169
- self,
170
- *,
171
- heads: int,
172
- n_data: Optional[int] = None,
173
- width=None,
174
- qk_norm=False,
175
- norm_layer=nn.LayerNorm
176
- ):
177
- super().__init__()
178
- self.heads = heads
179
- self.n_data = n_data
180
- self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
181
- self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
182
-
183
- self.attn_processor = CrossAttentionProcessor()
184
-
185
- def forward(self, q, kv):
186
- _, n_ctx, _ = q.shape
187
- bs, n_data, width = kv.shape
188
- attn_ch = width // self.heads // 2
189
- q = q.view(bs, n_ctx, self.heads, -1)
190
- kv = kv.view(bs, n_data, self.heads, -1)
191
- k, v = torch.split(kv, attn_ch, dim=-1)
192
-
193
- q = self.q_norm(q)
194
- k = self.k_norm(k)
195
- q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
196
- out = self.attn_processor(self, q, k, v)
197
- out = out.transpose(1, 2).reshape(bs, n_ctx, -1)
198
- return out
199
-
200
-
201
- class MultiheadCrossAttention(nn.Module):
202
- def __init__(
203
- self,
204
- *,
205
- width: int,
206
- heads: int,
207
- qkv_bias: bool = True,
208
- n_data: Optional[int] = None,
209
- data_width: Optional[int] = None,
210
- norm_layer=nn.LayerNorm,
211
- qk_norm: bool = False,
212
- kv_cache: bool = False,
213
- ):
214
- super().__init__()
215
- self.n_data = n_data
216
- self.width = width
217
- self.heads = heads
218
- self.data_width = width if data_width is None else data_width
219
- self.c_q = nn.Linear(width, width, bias=qkv_bias)
220
- self.c_kv = nn.Linear(self.data_width, width * 2, bias=qkv_bias)
221
- self.c_proj = nn.Linear(width, width)
222
- self.attention = QKVMultiheadCrossAttention(
223
- heads=heads,
224
- n_data=n_data,
225
- width=width,
226
- norm_layer=norm_layer,
227
- qk_norm=qk_norm
228
- )
229
- self.kv_cache = kv_cache
230
- self.data = None
231
-
232
- def forward(self, x, data):
233
- x = self.c_q(x)
234
- if self.kv_cache:
235
- if self.data is None:
236
- self.data = self.c_kv(data)
237
- logger.info('Save kv cache,this should be called only once for one mesh')
238
- data = self.data
239
- else:
240
- data = self.c_kv(data)
241
- x = self.attention(x, data)
242
- x = self.c_proj(x)
243
- return x
244
-
245
-
246
- class ResidualCrossAttentionBlock(nn.Module):
247
- def __init__(
248
- self,
249
- *,
250
- n_data: Optional[int] = None,
251
- width: int,
252
- heads: int,
253
- mlp_expand_ratio: int = 4,
254
- data_width: Optional[int] = None,
255
- qkv_bias: bool = True,
256
- norm_layer=nn.LayerNorm,
257
- qk_norm: bool = False
258
- ):
259
- super().__init__()
260
-
261
- if data_width is None:
262
- data_width = width
263
-
264
- self.attn = MultiheadCrossAttention(
265
- n_data=n_data,
266
- width=width,
267
- heads=heads,
268
- data_width=data_width,
269
- qkv_bias=qkv_bias,
270
- norm_layer=norm_layer,
271
- qk_norm=qk_norm
272
- )
273
- self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
274
- self.ln_2 = norm_layer(data_width, elementwise_affine=True, eps=1e-6)
275
- self.ln_3 = norm_layer(width, elementwise_affine=True, eps=1e-6)
276
- self.mlp = MLP(width=width, expand_ratio=mlp_expand_ratio)
277
-
278
- def forward(self, x: torch.Tensor, data: torch.Tensor):
279
- x = x + self.attn(self.ln_1(x), self.ln_2(data))
280
- x = x + self.mlp(self.ln_3(x))
281
- return x
282
-
283
-
284
- class QKVMultiheadAttention(nn.Module):
285
- def __init__(
286
- self,
287
- *,
288
- heads: int,
289
- n_ctx: int,
290
- width=None,
291
- qk_norm=False,
292
- norm_layer=nn.LayerNorm
293
- ):
294
- super().__init__()
295
- self.heads = heads
296
- self.n_ctx = n_ctx
297
- self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
298
- self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
299
-
300
- def forward(self, qkv):
301
- bs, n_ctx, width = qkv.shape
302
- attn_ch = width // self.heads // 3
303
- qkv = qkv.view(bs, n_ctx, self.heads, -1)
304
- q, k, v = torch.split(qkv, attn_ch, dim=-1)
305
-
306
- q = self.q_norm(q)
307
- k = self.k_norm(k)
308
-
309
- q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
310
- out = scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1)
311
- return out
312
-
313
-
314
- class MultiheadAttention(nn.Module):
315
- def __init__(
316
- self,
317
- *,
318
- n_ctx: int,
319
- width: int,
320
- heads: int,
321
- qkv_bias: bool,
322
- norm_layer=nn.LayerNorm,
323
- qk_norm: bool = False,
324
- drop_path_rate: float = 0.0
325
- ):
326
- super().__init__()
327
- self.n_ctx = n_ctx
328
- self.width = width
329
- self.heads = heads
330
- self.c_qkv = nn.Linear(width, width * 3, bias=qkv_bias)
331
- self.c_proj = nn.Linear(width, width)
332
- self.attention = QKVMultiheadAttention(
333
- heads=heads,
334
- n_ctx=n_ctx,
335
- width=width,
336
- norm_layer=norm_layer,
337
- qk_norm=qk_norm
338
- )
339
- self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
340
-
341
- def forward(self, x):
342
- x = self.c_qkv(x)
343
- x = self.attention(x)
344
- x = self.drop_path(self.c_proj(x))
345
- return x
346
-
347
-
348
- class ResidualAttentionBlock(nn.Module):
349
- def __init__(
350
- self,
351
- *,
352
- n_ctx: int,
353
- width: int,
354
- heads: int,
355
- qkv_bias: bool = True,
356
- norm_layer=nn.LayerNorm,
357
- qk_norm: bool = False,
358
- drop_path_rate: float = 0.0,
359
- ):
360
- super().__init__()
361
- self.attn = MultiheadAttention(
362
- n_ctx=n_ctx,
363
- width=width,
364
- heads=heads,
365
- qkv_bias=qkv_bias,
366
- norm_layer=norm_layer,
367
- qk_norm=qk_norm,
368
- drop_path_rate=drop_path_rate
369
- )
370
- self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
371
- self.mlp = MLP(width=width, drop_path_rate=drop_path_rate)
372
- self.ln_2 = norm_layer(width, elementwise_affine=True, eps=1e-6)
373
-
374
- def forward(self, x: torch.Tensor):
375
- x = x + self.attn(self.ln_1(x))
376
- x = x + self.mlp(self.ln_2(x))
377
- return x
378
-
379
-
380
- class Transformer(nn.Module):
381
- def __init__(
382
- self,
383
- *,
384
- n_ctx: int,
385
- width: int,
386
- layers: int,
387
- heads: int,
388
- qkv_bias: bool = True,
389
- norm_layer=nn.LayerNorm,
390
- qk_norm: bool = False,
391
- drop_path_rate: float = 0.0
392
- ):
393
- super().__init__()
394
- self.n_ctx = n_ctx
395
- self.width = width
396
- self.layers = layers
397
- self.resblocks = nn.ModuleList(
398
- [
399
- ResidualAttentionBlock(
400
- n_ctx=n_ctx,
401
- width=width,
402
- heads=heads,
403
- qkv_bias=qkv_bias,
404
- norm_layer=norm_layer,
405
- qk_norm=qk_norm,
406
- drop_path_rate=drop_path_rate
407
- )
408
- for _ in range(layers)
409
- ]
410
- )
411
-
412
- def forward(self, x: torch.Tensor):
413
- for block in self.resblocks:
414
- x = block(x)
415
- return x
416
-
417
-
418
- class CrossAttentionDecoder(nn.Module):
419
-
420
- def __init__(
421
- self,
422
- *,
423
- num_latents: int,
424
- out_channels: int,
425
- fourier_embedder: FourierEmbedder,
426
- width: int,
427
- heads: int,
428
- mlp_expand_ratio: int = 4,
429
- downsample_ratio: int = 1,
430
- enable_ln_post: bool = True,
431
- qkv_bias: bool = True,
432
- qk_norm: bool = False,
433
- label_type: str = "binary"
434
- ):
435
- super().__init__()
436
-
437
- self.enable_ln_post = enable_ln_post
438
- self.fourier_embedder = fourier_embedder
439
- self.downsample_ratio = downsample_ratio
440
- self.query_proj = nn.Linear(self.fourier_embedder.out_dim, width)
441
- if self.downsample_ratio != 1:
442
- self.latents_proj = nn.Linear(width * downsample_ratio, width)
443
- if self.enable_ln_post == False:
444
- qk_norm = False
445
- self.cross_attn_decoder = ResidualCrossAttentionBlock(
446
- n_data=num_latents,
447
- width=width,
448
- mlp_expand_ratio=mlp_expand_ratio,
449
- heads=heads,
450
- qkv_bias=qkv_bias,
451
- qk_norm=qk_norm
452
- )
453
-
454
- if self.enable_ln_post:
455
- self.ln_post = nn.LayerNorm(width)
456
- self.output_proj = nn.Linear(width, out_channels)
457
- self.label_type = label_type
458
- self.count = 0
459
-
460
- def set_cross_attention_processor(self, processor):
461
- self.cross_attn_decoder.attn.attention.attn_processor = processor
462
-
463
- def set_default_cross_attention_processor(self):
464
- self.cross_attn_decoder.attn.attention.attn_processor = CrossAttentionProcessor
465
-
466
- def forward(self, queries=None, query_embeddings=None, latents=None):
467
- if query_embeddings is None:
468
- query_embeddings = self.query_proj(self.fourier_embedder(queries).to(latents.dtype))
469
- self.count += query_embeddings.shape[1]
470
- if self.downsample_ratio != 1:
471
- latents = self.latents_proj(latents)
472
- x = self.cross_attn_decoder(query_embeddings, latents)
473
- if self.enable_ln_post:
474
- x = self.ln_post(x)
475
- occ = self.output_proj(x)
476
- return occ
477
-
478
-
479
- def fps(
480
- src: torch.Tensor,
481
- batch: Optional[Tensor] = None,
482
- ratio: Optional[Union[Tensor, float]] = None,
483
- random_start: bool = True,
484
- batch_size: Optional[int] = None,
485
- ptr: Optional[Union[Tensor, List[int]]] = None,
486
- ):
487
- src = src.float()
488
- from torch_cluster import fps as fps_fn
489
- output = fps_fn(src, batch, ratio, random_start, batch_size, ptr)
490
- return output
491
-
492
-
493
- class PointCrossAttentionEncoder(nn.Module):
494
-
495
- def __init__(
496
- self, *,
497
- num_latents: int,
498
- downsample_ratio: float,
499
- pc_size: int,
500
- pc_sharpedge_size: int,
501
- fourier_embedder: FourierEmbedder,
502
- point_feats: int,
503
- width: int,
504
- heads: int,
505
- layers: int,
506
- voxel_query_res: int,
507
- normal_pe: bool = False,
508
- qkv_bias: bool = True,
509
- use_ln_post: bool = False,
510
- use_checkpoint: bool = False,
511
- qk_norm: bool = False,
512
- jitter_query: bool = False,
513
- voxel_query: bool = False,
514
- ):
515
-
516
- super().__init__()
517
-
518
- self.use_checkpoint = use_checkpoint
519
- self.num_latents = num_latents
520
- self.downsample_ratio = downsample_ratio
521
- self.point_feats = point_feats
522
- self.normal_pe = normal_pe
523
- self.jitter_query = jitter_query
524
- self.voxel_query = voxel_query
525
- self.voxel_query_res = voxel_query_res
526
-
527
- if pc_sharpedge_size == 0:
528
- print(
529
- f'PointCrossAttentionEncoder INFO: pc_sharpedge_size is zero')
530
- else:
531
- print(
532
- f'PointCrossAttentionEncoder INFO: pc_sharpedge_size is given, using pc_size={pc_size}, pc_sharpedge_size={pc_sharpedge_size}')
533
-
534
- self.pc_size = pc_size
535
- self.pc_sharpedge_size = pc_sharpedge_size
536
-
537
- self.fourier_embedder = fourier_embedder
538
-
539
- if self.jitter_query or self.voxel_query:
540
- self.input_proj_q = nn.Linear(self.fourier_embedder.out_dim, width)
541
- self.input_proj_kv = nn.Linear(self.fourier_embedder.out_dim + point_feats, width)
542
- else:
543
- self.input_proj = nn.Linear(self.fourier_embedder.out_dim + point_feats, width)
544
- self.cross_attn = ResidualCrossAttentionBlock(
545
- width=width,
546
- heads=heads,
547
- qkv_bias=qkv_bias,
548
- qk_norm=qk_norm
549
- )
550
-
551
- self.self_attn = None
552
- if layers > 0:
553
- self.self_attn = Transformer(
554
- n_ctx=num_latents,
555
- width=width,
556
- layers=layers,
557
- heads=heads,
558
- qkv_bias=qkv_bias,
559
- qk_norm=qk_norm
560
- )
561
-
562
- if use_ln_post:
563
- self.ln_post = nn.LayerNorm(width)
564
- else:
565
- self.ln_post = None
566
-
567
- def sample_points_and_latents(self, pc: torch.FloatTensor, feats: Optional[torch.FloatTensor] = None):
568
- B, N, D = pc.shape
569
- num_pts = self.num_latents * self.downsample_ratio
570
-
571
- # Compute number of latents
572
- num_latents = int(num_pts / self.downsample_ratio)
573
-
574
- # Compute the number of random and sharpedge latents
575
- num_random_query = self.pc_size / (self.pc_size + self.pc_sharpedge_size) * num_latents
576
- num_sharpedge_query = num_latents - num_random_query
577
-
578
- # Split random and sharpedge surface points
579
- random_pc, sharpedge_pc = torch.split(pc, [self.pc_size, self.pc_sharpedge_size], dim=1)
580
- assert random_pc.shape[1] <= self.pc_size, "Random surface points size must be less than or equal to pc_size"
581
- assert sharpedge_pc.shape[1] <= self.pc_sharpedge_size, "Sharpedge surface points size must be less than or equal to pc_sharpedge_size"
582
-
583
- # Randomly select random surface points and random query points
584
- input_random_pc_size = int(num_random_query * self.downsample_ratio)
585
- random_query_ratio = num_random_query / input_random_pc_size
586
- idx_random_pc = torch.randperm(random_pc.shape[1], device=random_pc.device)[:input_random_pc_size]
587
- input_random_pc = random_pc[:, idx_random_pc, :]
588
-
589
- if self.voxel_query:
590
- query_random_pc, query_voxel_indices = voxelize_from_point(pc, num_latents, resolution=self.voxel_query_res)
591
- else:
592
- flatten_input_random_pc = input_random_pc.view(B * input_random_pc_size, D)
593
- N_down = int(flatten_input_random_pc.shape[0] / B)
594
- batch_down = torch.arange(B).to(pc.device)
595
- batch_down = torch.repeat_interleave(batch_down, N_down)
596
- idx_query_random = fps(flatten_input_random_pc, batch_down, ratio=random_query_ratio)
597
- query_random_pc = flatten_input_random_pc[idx_query_random].view(B, -1, D)
598
-
599
- # Randomly select sharpedge surface points and sharpedge query points
600
- input_sharpedge_pc_size = int(num_sharpedge_query * self.downsample_ratio)
601
- if input_sharpedge_pc_size == 0 or self.voxel_query:
602
- input_sharpedge_pc = torch.zeros(B, 0, D, dtype=input_random_pc.dtype).to(pc.device)
603
- query_sharpedge_pc = torch.zeros(B, 0, D, dtype=query_random_pc.dtype).to(pc.device)
604
- else:
605
- sharpedge_query_ratio = num_sharpedge_query / input_sharpedge_pc_size
606
- idx_sharpedge_pc = torch.randperm(sharpedge_pc.shape[1], device=sharpedge_pc.device)[
607
- :input_sharpedge_pc_size]
608
- input_sharpedge_pc = sharpedge_pc[:, idx_sharpedge_pc, :]
609
- flatten_input_sharpedge_surface_points = input_sharpedge_pc.view(B * input_sharpedge_pc_size, D)
610
- N_down = int(flatten_input_sharpedge_surface_points.shape[0] / B)
611
- batch_down = torch.arange(B).to(pc.device)
612
- batch_down = torch.repeat_interleave(batch_down, N_down)
613
- idx_query_sharpedge = fps(flatten_input_sharpedge_surface_points, batch_down, ratio=sharpedge_query_ratio)
614
- query_sharpedge_pc = flatten_input_sharpedge_surface_points[idx_query_sharpedge].view(B, -1, D)
615
-
616
- # Concatenate random and sharpedge surface points and query points
617
- query_pc = torch.cat([query_random_pc, query_sharpedge_pc], dim=1)
618
- input_pc = torch.cat([input_random_pc, input_sharpedge_pc], dim=1)
619
-
620
- if self.jitter_query:
621
- R = self.voxel_query_res // 2
622
- noise = torch.rand_like(query_pc)
623
- query_pc += (noise - 0.5) / R
624
-
625
- # PE
626
- query = self.fourier_embedder(query_pc)
627
- data = self.fourier_embedder(input_pc)
628
-
629
- # Concat normal if given
630
- if self.point_feats != 0:
631
-
632
- random_surface_feats, sharpedge_surface_feats = torch.split(feats, [self.pc_size, self.pc_sharpedge_size],
633
- dim=1)
634
- input_random_surface_feats = random_surface_feats[:, idx_random_pc, :]
635
- if not self.voxel_query and not self.jitter_query:
636
- flatten_input_random_surface_feats = input_random_surface_feats.view(B * input_random_pc_size, -1)
637
- query_random_feats = flatten_input_random_surface_feats[idx_query_random].view(B, -1,
638
- flatten_input_random_surface_feats.shape[
639
- -1])
640
-
641
- if input_sharpedge_pc_size == 0:
642
- input_sharpedge_surface_feats = torch.zeros(B, 0, self.point_feats,
643
- dtype=input_random_surface_feats.dtype).to(pc.device)
644
- if not self.voxel_query and not self.jitter_query:
645
- query_sharpedge_feats = torch.zeros(B, 0, self.point_feats, dtype=query_random_feats.dtype).to(
646
- pc.device)
647
- else:
648
- input_sharpedge_surface_feats = sharpedge_surface_feats[:, idx_sharpedge_pc, :]
649
- if not self.voxel_query and not self.jitter_query:
650
- flatten_input_sharpedge_surface_feats = input_sharpedge_surface_feats.view(B * input_sharpedge_pc_size,
651
- -1)
652
- query_sharpedge_feats = flatten_input_sharpedge_surface_feats[idx_query_sharpedge].view(B, -1,
653
- flatten_input_sharpedge_surface_feats.shape[
654
- -1])
655
- if not self.voxel_query and not self.jitter_query:
656
- query_feats = torch.cat([query_random_feats, query_sharpedge_feats], dim=1)
657
- input_feats = torch.cat([input_random_surface_feats, input_sharpedge_surface_feats], dim=1)
658
-
659
- if self.normal_pe:
660
- if not self.voxel_query and not self.jitter_query:
661
- query_normal_pe = self.fourier_embedder(query_feats[..., :3])
662
- query_feats = torch.cat([query_normal_pe, query_feats[..., 3:]], dim=-1)
663
- input_normal_pe = self.fourier_embedder(input_feats[..., :3])
664
- input_feats = torch.cat([input_normal_pe, input_feats[..., 3:]], dim=-1)
665
-
666
- if not self.voxel_query and not self.jitter_query:
667
- query = torch.cat([query, query_feats], dim=-1)
668
- data = torch.cat([data, input_feats], dim=-1)
669
-
670
- if input_sharpedge_pc_size == 0:
671
- query_sharpedge_pc = torch.zeros(B, 1, D).to(pc.device)
672
- input_sharpedge_pc = torch.zeros(B, 1, D).to(pc.device)
673
-
674
- if self.voxel_query:
675
- pc_infos = [query_voxel_indices, query_random_pc]
676
- else:
677
- pc_infos = [query_pc, input_pc, query_random_pc, input_random_pc, query_sharpedge_pc, input_sharpedge_pc]
678
- return query.view(B, -1, query.shape[-1]), data.view(B, -1, data.shape[-1]), pc_infos
679
-
680
-
681
- def forward(self, pc, feats):
682
- """
683
-
684
- Args:
685
- pc (torch.FloatTensor): [B, N, 3]
686
- feats (torch.FloatTensor or None): [B, N, C]
687
-
688
- Returns:
689
-
690
- """
691
- query, data, pc_infos = self.sample_points_and_latents(pc, feats)
692
-
693
- if self.jitter_query or self.voxel_query:
694
- query = self.input_proj_q(query)
695
- query = query
696
- data = self.input_proj_kv(data)
697
- data = data
698
- else:
699
- query = self.input_proj(query)
700
- query = query
701
- data = self.input_proj(data)
702
- data = data
703
-
704
- latents = self.cross_attn(query, data)
705
- if self.self_attn is not None:
706
- latents = self.self_attn(latents)
707
-
708
- if self.ln_post is not None:
709
- latents = self.ln_post(latents)
710
-
711
- return latents, pc_infos
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ultrashape/models/autoencoders/attention_processors.py DELETED
@@ -1,103 +0,0 @@
1
- # ==============================================================================
2
- # Original work Copyright (c) 2025 Tencent.
3
- # Modified work Copyright (c) 2025 UltraShape Team.
4
- #
5
- # Modified by UltraShape on 2025.12.25
6
- # ==============================================================================
7
-
8
- # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
9
- # except for the third-party components listed below.
10
- # Hunyuan 3D does not impose any additional limitations beyond what is outlined
11
- # in the repsective licenses of these third-party components.
12
- # Users must comply with all terms and conditions of original licenses of these third-party
13
- # components and must ensure that the usage of the third party components adheres to
14
- # all relevant laws and regulations.
15
-
16
- # For avoidance of doubts, Hunyuan 3D means the large language models and
17
- # their software and algorithms, including trained model weights, parameters (including
18
- # optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
19
- # fine-tuning enabling code and other elements of the foregoing made publicly available
20
- # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
21
-
22
- import os
23
-
24
- import torch
25
- import torch.nn.functional as F
26
-
27
- scaled_dot_product_attention = F.scaled_dot_product_attention
28
- if os.environ.get('CA_USE_SAGEATTN', '0') == '1':
29
- try:
30
- from sageattention import sageattn
31
- except ImportError:
32
- raise ImportError('Please install the package "sageattention" to use this USE_SAGEATTN.')
33
- scaled_dot_product_attention = sageattn
34
-
35
-
36
- class CrossAttentionProcessor:
37
- def __call__(self, attn, q, k, v):
38
- out = scaled_dot_product_attention(q, k, v)
39
- return out
40
-
41
-
42
- class FlashVDMCrossAttentionProcessor:
43
- def __init__(self, topk=None):
44
- self.topk = topk
45
-
46
- def __call__(self, attn, q, k, v):
47
- if k.shape[-2] == 3072:
48
- topk = 1024
49
- elif k.shape[-2] == 512:
50
- topk = 256
51
- else:
52
- topk = k.shape[-2] // 3
53
-
54
- if self.topk is True:
55
- q1 = q[:, :, ::100, :]
56
- sim = q1 @ k.transpose(-1, -2)
57
- sim = torch.mean(sim, -2)
58
- topk_ind = torch.topk(sim, dim=-1, k=topk).indices.squeeze(-2).unsqueeze(-1)
59
- topk_ind = topk_ind.expand(-1, -1, -1, v.shape[-1])
60
- v0 = torch.gather(v, dim=-2, index=topk_ind)
61
- k0 = torch.gather(k, dim=-2, index=topk_ind)
62
- out = scaled_dot_product_attention(q, k0, v0)
63
- elif self.topk is False:
64
- out = scaled_dot_product_attention(q, k, v)
65
- else:
66
- idx, counts = self.topk
67
- start = 0
68
- outs = []
69
- for grid_coord, count in zip(idx, counts):
70
- end = start + count
71
- q_chunk = q[:, :, start:end, :]
72
- k0, v0 = self.select_topkv(q_chunk, k, v, topk)
73
- out = scaled_dot_product_attention(q_chunk, k0, v0)
74
- outs.append(out)
75
- start += count
76
- out = torch.cat(outs, dim=-2)
77
- self.topk = False
78
- return out
79
-
80
- def select_topkv(self, q_chunk, k, v, topk):
81
- q1 = q_chunk[:, :, ::50, :]
82
- sim = q1 @ k.transpose(-1, -2)
83
- sim = torch.mean(sim, -2)
84
- topk_ind = torch.topk(sim, dim=-1, k=topk).indices.squeeze(-2).unsqueeze(-1)
85
- topk_ind = topk_ind.expand(-1, -1, -1, v.shape[-1])
86
- v0 = torch.gather(v, dim=-2, index=topk_ind)
87
- k0 = torch.gather(k, dim=-2, index=topk_ind)
88
- return k0, v0
89
-
90
-
91
- class FlashVDMTopMCrossAttentionProcessor(FlashVDMCrossAttentionProcessor):
92
- def select_topkv(self, q_chunk, k, v, topk):
93
- q1 = q_chunk[:, :, ::30, :]
94
- sim = q1 @ k.transpose(-1, -2)
95
- # sim = sim.to(torch.float32)
96
- sim = sim.softmax(-1)
97
- sim = torch.mean(sim, 1)
98
- activated_token = torch.where(sim > 1e-6)[2]
99
- index = torch.unique(activated_token, return_counts=True)[0].unsqueeze(0).unsqueeze(0).unsqueeze(-1)
100
- index = index.expand(-1, v.shape[1], -1, v.shape[-1])
101
- v0 = torch.gather(v, dim=-2, index=index)
102
- k0 = torch.gather(k, dim=-2, index=index)
103
- return k0, v0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ultrashape/models/autoencoders/model.py DELETED
@@ -1,377 +0,0 @@
1
- # ==============================================================================
2
- # Original work Copyright (c) 2025 Tencent.
3
- # Modified work Copyright (c) 2025 UltraShape Team.
4
- #
5
- # Modified by UltraShape on 2025.12.25
6
- # ==============================================================================
7
-
8
- # Open Source Model Licensed under the Apache License Version 2.0
9
- # and Other Licenses of the Third-Party Components therein:
10
- # The below Model in this distribution may have been modified by THL A29 Limited
11
- # ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
12
-
13
- # Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
14
- # The below software and/or models in this distribution may have been
15
- # modified by THL A29 Limited ("Tencent Modifications").
16
- # All Tencent Modifications are Copyright (C) THL A29 Limited.
17
-
18
- # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
19
- # except for the third-party components listed below.
20
- # Hunyuan 3D does not impose any additional limitations beyond what is outlined
21
- # in the repsective licenses of these third-party components.
22
- # Users must comply with all terms and conditions of original licenses of these third-party
23
- # components and must ensure that the usage of the third party components adheres to
24
- # all relevant laws and regulations.
25
-
26
- # For avoidance of doubts, Hunyuan 3D means the large language models and
27
- # their software and algorithms, including trained model weights, parameters (including
28
- # optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
29
- # fine-tuning enabling code and other elements of the foregoing made publicly available
30
- # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
31
-
32
- import os
33
- from typing import Union, List
34
-
35
- import numpy as np
36
- import torch
37
- import torch.nn as nn
38
- import yaml
39
-
40
- from .attention_blocks import FourierEmbedder, Transformer, CrossAttentionDecoder, PointCrossAttentionEncoder
41
- from .surface_extractors import MCSurfaceExtractor, SurfaceExtractors
42
- from .volume_decoders import VanillaVolumeDecoder, FlashVDMVolumeDecoding, HierarchicalVolumeDecoding
43
- from ...utils import logger, synchronize_timer, smart_load_model
44
-
45
-
46
- class DiagonalGaussianDistribution(object):
47
- def __init__(self, parameters: Union[torch.Tensor, List[torch.Tensor]], deterministic=False, feat_dim=1):
48
- """
49
- Initialize a diagonal Gaussian distribution with mean and log-variance parameters.
50
-
51
- Args:
52
- parameters (Union[torch.Tensor, List[torch.Tensor]]):
53
- Either a single tensor containing concatenated mean and log-variance along `feat_dim`,
54
- or a list of two tensors [mean, logvar].
55
- deterministic (bool, optional): If True, the distribution is deterministic (zero variance).
56
- Default is False. feat_dim (int, optional): Dimension along which mean and logvar are
57
- concatenated if parameters is a single tensor. Default is 1.
58
- """
59
- self.feat_dim = feat_dim
60
- self.parameters = parameters
61
-
62
- if isinstance(parameters, list):
63
- self.mean = parameters[0]
64
- self.logvar = parameters[1]
65
- else:
66
- self.mean, self.logvar = torch.chunk(parameters, 2, dim=feat_dim)
67
-
68
- self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
69
- self.deterministic = deterministic
70
- self.std = torch.exp(0.5 * self.logvar)
71
- self.var = torch.exp(self.logvar)
72
- if self.deterministic:
73
- self.var = self.std = torch.zeros_like(self.mean)
74
-
75
- def sample(self):
76
- """
77
- Sample from the diagonal Gaussian distribution.
78
-
79
- Returns:
80
- torch.Tensor: A sample tensor with the same shape as the mean.
81
- """
82
- x = self.mean + self.std * torch.randn_like(self.mean)
83
- return x
84
-
85
- def kl(self, other=None, dims=(1, 2)):
86
- """
87
- Compute the Kullback-Leibler (KL) divergence between this distribution and another.
88
-
89
- If `other` is None, compute KL divergence to a standard normal distribution N(0, I).
90
-
91
- Args:
92
- other (DiagonalGaussianDistribution, optional): Another diagonal Gaussian distribution.
93
- dims (tuple, optional): Dimensions along which to compute the mean KL divergence.
94
- Default is (1, 2, 3).
95
-
96
- Returns:
97
- torch.Tensor: The mean KL divergence value.
98
- """
99
- if self.deterministic:
100
- return torch.Tensor([0.])
101
- else:
102
- if other is None:
103
- return 0.5 * torch.mean(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=dims)
104
- else:
105
- return 0.5 * torch.mean(
106
- torch.pow(self.mean - other.mean, 2) / other.var
107
- + self.var / other.var - 1.0 - self.logvar + other.logvar,
108
- dim=dims)
109
-
110
- def nll(self, sample, dims=(1, 2, 3)):
111
- if self.deterministic:
112
- return torch.Tensor([0.])
113
- logtwopi = np.log(2.0 * np.pi)
114
- return 0.5 * torch.sum(
115
- logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
116
- dim=dims)
117
-
118
- def mode(self):
119
- return self.mean
120
-
121
-
122
- class VectsetVAE(nn.Module):
123
-
124
- @classmethod
125
- @synchronize_timer('VectsetVAE Model Loading')
126
- def from_single_file(
127
- cls,
128
- ckpt_path,
129
- config_path=None,
130
- params=None,
131
- device='cuda',
132
- dtype=torch.float16,
133
- use_safetensors=None,
134
- **kwargs,
135
- ):
136
- # load config
137
- with open(config_path, 'r') as f:
138
- config = yaml.safe_load(f)
139
-
140
- # load ckpt
141
- if use_safetensors:
142
- ckpt_path = ckpt_path.replace('.ckpt', '.safetensors')
143
- if not os.path.exists(ckpt_path):
144
- raise FileNotFoundError(f"Model file {ckpt_path} not found")
145
-
146
- logger.info(f"Loading model from {ckpt_path}")
147
- if use_safetensors:
148
- import safetensors.torch
149
- ckpt = safetensors.torch.load_file(ckpt_path, device='cpu')
150
- else:
151
- ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=True)
152
-
153
- if params is not None:
154
- model_kwargs = params
155
- else:
156
- model_kwargs = config['params']
157
- model_kwargs.update(kwargs)
158
-
159
- model = cls(**model_kwargs)
160
- model.load_state_dict(ckpt)
161
-
162
- model.to(device=device, dtype=dtype)
163
- return model
164
-
165
- @classmethod
166
- def from_pretrained(
167
- cls,
168
- model_path,
169
- device='cuda',
170
- params=None,
171
- dtype=torch.float16,
172
- use_safetensors=False,
173
- variant='fp16',
174
- subfolder='hunyuan3d-vae-v2-1',
175
- **kwargs,
176
- ):
177
- config_path, ckpt_path = smart_load_model(
178
- model_path,
179
- subfolder=subfolder,
180
- use_safetensors=use_safetensors,
181
- variant=variant
182
- )
183
-
184
- return cls.from_single_file(
185
- ckpt_path,
186
- config_path=config_path,
187
- params=params,
188
- device=device,
189
- dtype=dtype,
190
- use_safetensors=use_safetensors,
191
- **kwargs
192
- )
193
-
194
- def init_from_ckpt(self, path, ignore_keys=()):
195
- state_dict = torch.load(path, map_location="cpu")
196
- state_dict = state_dict.get("state_dict", state_dict)
197
- keys = list(state_dict.keys())
198
- for k in keys:
199
- for ik in ignore_keys:
200
- if k.startswith(ik):
201
- print("Deleting key {} from state_dict.".format(k))
202
- del state_dict[k]
203
- missing, unexpected = self.load_state_dict(state_dict, strict=False)
204
- print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
205
- if len(missing) > 0:
206
- print(f"Missing Keys: {missing}")
207
- print(f"Unexpected Keys: {unexpected}")
208
-
209
- def __init__(
210
- self,
211
- volume_decoder=None,
212
- surface_extractor=None
213
- ):
214
- super().__init__()
215
- if volume_decoder is None:
216
- volume_decoder = VanillaVolumeDecoder()
217
- if surface_extractor is None:
218
- surface_extractor = MCSurfaceExtractor()
219
- self.volume_decoder = volume_decoder
220
- self.surface_extractor = surface_extractor
221
-
222
- def latents2mesh(self, latents: torch.FloatTensor, **kwargs):
223
- with synchronize_timer('Volume decoding'):
224
- grid_logits = self.volume_decoder(latents, self.geo_decoder, **kwargs)
225
- with synchronize_timer('Surface extraction'):
226
- outputs = self.surface_extractor(grid_logits, **kwargs)
227
- return outputs, grid_logits
228
-
229
- def enable_flashvdm_decoder(
230
- self,
231
- enabled: bool = True,
232
- adaptive_kv_selection=True,
233
- topk_mode='mean',
234
- mc_algo='mc',
235
- ):
236
- if enabled:
237
- if adaptive_kv_selection:
238
- self.volume_decoder = FlashVDMVolumeDecoding(topk_mode)
239
- else:
240
- self.volume_decoder = HierarchicalVolumeDecoding()
241
- if mc_algo not in SurfaceExtractors.keys():
242
- raise ValueError(f'Unsupported mc_algo {mc_algo}, available:{list(SurfaceExtractors.keys())}')
243
- self.surface_extractor = SurfaceExtractors[mc_algo]()
244
- else:
245
- self.volume_decoder = VanillaVolumeDecoder()
246
- self.surface_extractor = MCSurfaceExtractor()
247
-
248
-
249
- class ShapeVAE(VectsetVAE):
250
- def __init__(
251
- self,
252
- *,
253
- num_latents: int,
254
- embed_dim: int,
255
- width: int,
256
- heads: int,
257
- num_decoder_layers: int,
258
- num_encoder_layers: int = 8,
259
- pc_size: int = 5120,
260
- pc_sharpedge_size: int = 5120,
261
- point_feats: int = 3,
262
- downsample_ratio: int = 20,
263
- geo_decoder_downsample_ratio: int = 1,
264
- geo_decoder_mlp_expand_ratio: int = 4,
265
- geo_decoder_ln_post: bool = True,
266
- num_freqs: int = 8,
267
- include_pi: bool = True,
268
- qkv_bias: bool = True,
269
- qk_norm: bool = False,
270
- label_type: str = "binary",
271
- drop_path_rate: float = 0.0,
272
- scale_factor: float = 1.0,
273
- use_ln_post: bool = True,
274
- enable_flashvdm: bool = False,
275
- ckpt_path = None,
276
- jitter_query: bool = False,
277
- voxel_query: bool = False,
278
- voxel_query_res: int = 128,
279
- ):
280
- super().__init__()
281
- self.geo_decoder_ln_post = geo_decoder_ln_post
282
- self.downsample_ratio = downsample_ratio
283
-
284
- self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
285
-
286
- self.encoder = PointCrossAttentionEncoder(
287
- fourier_embedder=self.fourier_embedder,
288
- num_latents=num_latents,
289
- downsample_ratio=self.downsample_ratio,
290
- pc_size=pc_size,
291
- pc_sharpedge_size=pc_sharpedge_size,
292
- point_feats=point_feats,
293
- width=width,
294
- heads=heads,
295
- layers=num_encoder_layers,
296
- qkv_bias=qkv_bias,
297
- use_ln_post=use_ln_post,
298
- qk_norm=qk_norm,
299
- jitter_query=jitter_query,
300
- voxel_query=voxel_query,
301
- voxel_query_res=voxel_query_res
302
- )
303
-
304
- self.pre_kl = nn.Linear(width, embed_dim * 2)
305
- self.post_kl = nn.Linear(embed_dim, width)
306
-
307
- self.transformer = Transformer(
308
- n_ctx=num_latents,
309
- width=width,
310
- layers=num_decoder_layers,
311
- heads=heads,
312
- qkv_bias=qkv_bias,
313
- qk_norm=qk_norm,
314
- drop_path_rate=drop_path_rate
315
- )
316
-
317
- self.geo_decoder = CrossAttentionDecoder(
318
- fourier_embedder=self.fourier_embedder,
319
- out_channels=1,
320
- num_latents=num_latents,
321
- mlp_expand_ratio=geo_decoder_mlp_expand_ratio,
322
- downsample_ratio=geo_decoder_downsample_ratio,
323
- enable_ln_post=self.geo_decoder_ln_post,
324
- width=width // geo_decoder_downsample_ratio,
325
- heads=heads // geo_decoder_downsample_ratio,
326
- qkv_bias=qkv_bias,
327
- qk_norm=qk_norm,
328
- label_type=label_type,
329
- )
330
-
331
- self.scale_factor = scale_factor
332
- self.latent_shape = (num_latents, embed_dim)
333
-
334
- if ckpt_path is not None:
335
- self.init_from_ckpt(ckpt_path)
336
-
337
- if enable_flashvdm:
338
- self.enable_flashvdm_decoder()
339
-
340
- def forward(self, latents):
341
- latents = self.post_kl(latents)
342
- latents = self.transformer(latents)
343
- return latents
344
-
345
- def encode(self, surface, sample_posterior=True, need_kl=False, need_voxel=False):
346
- pc, feats = surface[:, :, :3], surface[:, :, 3:]
347
- latents, pc_infos = self.encoder(pc, feats)
348
- # print(latents.shape, self.pre_kl.weight.shape)
349
- moments = self.pre_kl(latents)
350
- posterior = DiagonalGaussianDistribution(moments, feat_dim=-1)
351
- if sample_posterior:
352
- latents = posterior.sample()
353
- else:
354
- latents = posterior.mode()
355
- if need_kl:
356
- return latents, posterior
357
- if need_voxel:
358
- return latents, pc_infos[0]
359
- return latents
360
-
361
- def decode(self, latents, voxel_idx=None):
362
- latents = self.post_kl(latents)
363
- latents = self.transformer(latents)
364
- return latents
365
-
366
- def query(self, latents, queries, voxel_idx=None):
367
- """
368
- Args:
369
- queries (torch.FloatTensor): [B, N, 3]
370
- latents (torch.FloatTensor): [B, embed_dim]
371
-
372
- Returns:
373
- logits (torch.FloatTensor): [B, N], occupancy logits
374
- """
375
- logits = self.geo_decoder(queries=queries, latents=latents).squeeze(-1)
376
-
377
- return logits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ultrashape/models/autoencoders/surface_extractors.py DELETED
@@ -1,266 +0,0 @@
1
- # ==============================================================================
2
- # Original work Copyright (c) 2025 Tencent.
3
- # Modified work Copyright (c) 2025 UltraShape Team.
4
- #
5
- # Modified by UltraShape on 2025.12.25
6
- # ==============================================================================
7
-
8
- # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
9
- # except for the third-party components listed below.
10
- # Hunyuan 3D does not impose any additional limitations beyond what is outlined
11
- # in the repsective licenses of these third-party components.
12
- # Users must comply with all terms and conditions of original licenses of these third-party
13
- # components and must ensure that the usage of the third party components adheres to
14
- # all relevant laws and regulations.
15
-
16
- # For avoidance of doubts, Hunyuan 3D means the large language models and
17
- # their software and algorithms, including trained model weights, parameters (including
18
- # optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
19
- # fine-tuning enabling code and other elements of the foregoing made publicly available
20
- # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
21
-
22
- from typing import Union, Tuple, List
23
-
24
- import numpy as np
25
- import torch
26
- from skimage import measure
27
- import cubvh
28
-
29
-
30
- class Latent2MeshOutput:
31
- def __init__(self, mesh_v=None, mesh_f=None):
32
- self.mesh_v = mesh_v
33
- self.mesh_f = mesh_f
34
-
35
-
36
- def center_vertices(vertices):
37
- """Translate the vertices so that bounding box is centered at zero."""
38
- vert_min = vertices.min(dim=0)[0]
39
- vert_max = vertices.max(dim=0)[0]
40
- vert_center = 0.5 * (vert_min + vert_max)
41
- return vertices - vert_center
42
-
43
-
44
- class SurfaceExtractor:
45
- def _compute_box_stat(self, bounds: Union[Tuple[float], List[float], float], octree_resolution: int):
46
- """
47
- Compute grid size, bounding box minimum coordinates, and bounding box size based on input
48
- bounds and resolution.
49
-
50
- Args:
51
- bounds (Union[Tuple[float], List[float], float]): Bounding box coordinates or a single
52
- float representing half side length.
53
- If float, bounds are assumed symmetric around zero in all axes.
54
- Expected format if list/tuple: [xmin, ymin, zmin, xmax, ymax, zmax].
55
- octree_resolution (int): Resolution of the octree grid.
56
-
57
- Returns:
58
- grid_size (List[int]): Grid size along each axis (x, y, z), each equal to octree_resolution + 1.
59
- bbox_min (np.ndarray): Minimum coordinates of the bounding box (xmin, ymin, zmin).
60
- bbox_size (np.ndarray): Size of the bounding box along each axis (xmax - xmin, etc.).
61
- """
62
- if isinstance(bounds, float):
63
- bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
64
-
65
- bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6])
66
- bbox_size = bbox_max - bbox_min
67
- grid_size = [int(octree_resolution) + 1, int(octree_resolution) + 1, int(octree_resolution) + 1]
68
- return grid_size, bbox_min, bbox_size
69
-
70
- def run(self, *args, **kwargs):
71
- """
72
- Abstract method to extract surface mesh from grid logits.
73
-
74
- This method should be implemented by subclasses.
75
-
76
- Raises:
77
- NotImplementedError: Always, since this is an abstract method.
78
- """
79
- return NotImplementedError
80
-
81
- def __call__(self, grid_logits, **kwargs):
82
- """
83
- Process a batch of grid logits to extract surface meshes.
84
-
85
- Args:
86
- grid_logits (torch.Tensor): Batch of grid logits with shape (batch_size, ...).
87
- **kwargs: Additional keyword arguments passed to the `run` method.
88
-
89
- Returns:
90
- List[Optional[Latent2MeshOutput]]: List of mesh outputs for each grid in the batch.
91
- If extraction fails for a grid, None is appended at that position.
92
- """
93
- outputs = []
94
- for i in range(grid_logits.shape[0]):
95
- try:
96
- vertices, faces = self.run(grid_logits[i], **kwargs)
97
- vertices = vertices.astype(np.float32)
98
- faces = np.ascontiguousarray(faces)
99
- outputs.append(Latent2MeshOutput(mesh_v=vertices, mesh_f=faces))
100
-
101
- except Exception:
102
- import traceback
103
- traceback.print_exc()
104
- outputs.append(None)
105
-
106
- return outputs
107
-
108
-
109
- def get_sparse_valid_voxels(grid_logit: torch.Tensor):
110
-
111
- if not isinstance(grid_logit, torch.Tensor):
112
- raise TypeError("Input must be a PyTorch tensor.")
113
- if grid_logit.dim() != 3 or grid_logit.shape[0] != grid_logit.shape[1] or grid_logit.shape[0] != grid_logit.shape[2]:
114
- raise ValueError("Input tensor must have shape (N, N, N)")
115
-
116
- N = grid_logit.shape[0]
117
- device = grid_logit.device
118
-
119
- # Chunk processing to save memory
120
- chunk_size = 128
121
-
122
- all_sparse_coords = []
123
- all_sparse_logits = []
124
-
125
- # Process in chunks along x-axis
126
- for start_x in range(0, N - 1, chunk_size):
127
- end_x = min(start_x + chunk_size, N - 1)
128
-
129
- # Determine slice range including +1 for neighbor checks
130
- # slice_end needs to be end_x + 1 to include the neighbors for the last voxel in chunk
131
- slice_end = end_x + 1
132
-
133
- chunk = grid_logit[start_x:slice_end, :, :]
134
- nan_mask = torch.isnan(chunk)
135
-
136
- # Compute mask for this chunk (valid voxels are 0 to end_x - start_x)
137
- # Note: chunk shape is [D_chunk, N, N].
138
- # We want to check validity for [0..D_chunk-1, :-1, :-1]
139
-
140
- sub_nan_mask = nan_mask
141
-
142
- # Validity check requires looking at i and i+1
143
- # Invalid if ANY corner is NaN
144
- invalid_voxel_mask = (
145
- sub_nan_mask[:-1, :-1, :-1] |
146
- sub_nan_mask[1:, :-1, :-1] |
147
- sub_nan_mask[:-1, 1:, :-1] |
148
- sub_nan_mask[:-1, :-1, 1:] |
149
- sub_nan_mask[:-1, 1:, 1:] |
150
- sub_nan_mask[1:, :-1, 1:] |
151
- sub_nan_mask[1:, 1:, :-1] |
152
- sub_nan_mask[1:, 1:, 1:]
153
- )
154
-
155
- valid_voxel_mask = ~invalid_voxel_mask
156
-
157
- # Get local coordinates
158
- local_coords = valid_voxel_mask.nonzero(as_tuple=False)
159
-
160
- if local_coords.shape[0] > 0:
161
- lx, ly, lz = local_coords[:, 0], local_coords[:, 1], local_coords[:, 2]
162
-
163
- # Extract logits using local indices on the chunk
164
- # v0 is at lx, v1 is at lx+1, etc.
165
- sparse_vertex_logits = torch.stack([
166
- chunk[lx, ly, lz], # v0
167
- chunk[lx + 1, ly, lz], # v1
168
- chunk[lx + 1, ly + 1, lz], # v2
169
- chunk[lx, ly + 1, lz], # v3
170
- chunk[lx, ly, lz + 1], # v4
171
- chunk[lx + 1, ly, lz + 1], # v5
172
- chunk[lx + 1, ly + 1, lz + 1], # v6
173
- chunk[lx, ly + 1, lz + 1] # v7
174
- ], dim=1)
175
-
176
- # Convert local coords to global coords
177
- # x coordinate needs offset added
178
- global_coords = local_coords.clone()
179
- global_coords[:, 0] += start_x
180
-
181
- all_sparse_coords.append(global_coords)
182
- all_sparse_logits.append(sparse_vertex_logits)
183
-
184
- # Free memory
185
- del chunk, nan_mask, invalid_voxel_mask, valid_voxel_mask, local_coords
186
-
187
- if not all_sparse_coords:
188
- return torch.empty((0, 3), dtype=torch.long, device=device), torch.empty((0, 8), dtype=grid_logit.dtype, device=device)
189
-
190
- sparse_coords = torch.cat(all_sparse_coords, dim=0)
191
- sparse_vertex_logits = torch.cat(all_sparse_logits, dim=0)
192
-
193
- return sparse_coords, sparse_vertex_logits
194
-
195
-
196
- class MCSurfaceExtractor(SurfaceExtractor):
197
- def run(self, grid_logit, *, mc_level, bounds, octree_resolution, **kwargs):
198
- """
199
- Extract surface mesh using the Marching Cubes algorithm.
200
-
201
- Args:
202
- grid_logit (torch.Tensor): 3D grid logits tensor representing the scalar field.
203
- mc_level (float): The level (iso-value) at which to extract the surface.
204
- bounds (Union[Tuple[float], List[float], float]): Bounding box coordinates or half side length.
205
- octree_resolution (int): Resolution of the octree grid.
206
- **kwargs: Additional keyword arguments (ignored).
207
-
208
- Returns:
209
- Tuple[np.ndarray, np.ndarray]: Tuple containing:
210
- - vertices (np.ndarray): Extracted mesh vertices, scaled and translated to bounding
211
- box coordinates.
212
- - faces (np.ndarray): Extracted mesh faces (triangles).
213
- """
214
-
215
- grid_logit = grid_logit.detach()
216
-
217
- sparse_coords, sparse_logits = get_sparse_valid_voxels(grid_logit)
218
- # Convert to float32 only for the sparse set
219
- vertices, faces = cubvh.sparse_marching_cubes(sparse_coords, sparse_logits.float(), mc_level)
220
-
221
- vertices, faces = vertices.cpu().numpy(), faces.cpu().numpy()
222
- # vertices, faces, normals, _ = measure.marching_cubes(grid_logit,
223
- # mc_level, method="lewiner", mask=(~np.isnan(grid_logit)))
224
- grid_size, bbox_min, bbox_size = self._compute_box_stat(bounds, octree_resolution)
225
- vertices = vertices / grid_size * bbox_size + bbox_min
226
- return vertices, faces
227
-
228
-
229
- class DMCSurfaceExtractor(SurfaceExtractor):
230
- def run(self, grid_logit, *, octree_resolution, **kwargs):
231
- """
232
- Extract surface mesh using Differentiable Marching Cubes (DMC) algorithm.
233
-
234
- Args:
235
- grid_logit (torch.Tensor): 3D grid logits tensor representing the scalar field.
236
- octree_resolution (int): Resolution of the octree grid.
237
- **kwargs: Additional keyword arguments (ignored).
238
-
239
- Returns:
240
- Tuple[np.ndarray, np.ndarray]: Tuple containing:
241
- - vertices (np.ndarray): Extracted mesh vertices, centered and converted to numpy.
242
- - faces (np.ndarray): Extracted mesh faces (triangles), with reversed vertex order.
243
-
244
- Raises:
245
- ImportError: If the 'diso' package is not installed.
246
- """
247
- device = grid_logit.device
248
- if not hasattr(self, 'dmc'):
249
- try:
250
- from diso import DiffDMC
251
- self.dmc = DiffDMC(dtype=torch.float32).to(device)
252
- except:
253
- raise ImportError("Please install diso via `pip install diso`, or set mc_algo to 'mc'")
254
- sdf = -grid_logit / octree_resolution
255
- sdf = sdf.to(torch.float32).contiguous()
256
- verts, faces = self.dmc(sdf, deform=None, return_quads=False, normalize=True)
257
- verts = center_vertices(verts)
258
- vertices = verts.detach().cpu().numpy()
259
- faces = faces.detach().cpu().numpy()[:, ::-1]
260
- return vertices, faces
261
-
262
-
263
- SurfaceExtractors = {
264
- 'mc': MCSurfaceExtractor,
265
- 'dmc': DMCSurfaceExtractor,
266
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ultrashape/models/autoencoders/vae_trainer.py DELETED
@@ -1,229 +0,0 @@
1
- import os
2
- from contextlib import contextmanager
3
- from typing import List, Tuple, Optional, Union
4
-
5
- import torch
6
- import torch.nn as nn
7
- from torch.optim import lr_scheduler
8
- import pytorch_lightning as pl
9
- from pytorch_lightning.utilities import rank_zero_info
10
- from pytorch_lightning.utilities import rank_zero_only
11
- import trimesh
12
-
13
- from ...utils.misc import instantiate_from_config, instantiate_non_trainable_model, instantiate_vae_model
14
-
15
-
16
- def export_to_trimesh(mesh_output):
17
- if isinstance(mesh_output, list):
18
- outputs = []
19
- for mesh in mesh_output:
20
- if mesh is None:
21
- outputs.append(None)
22
- else:
23
- mesh.mesh_f = mesh.mesh_f[:, ::-1]
24
- mesh_output = trimesh.Trimesh(mesh.mesh_v, mesh.mesh_f)
25
- outputs.append(mesh_output)
26
- return outputs
27
- else:
28
- mesh_output.mesh_f = mesh_output.mesh_f[:, ::-1]
29
- mesh_output = trimesh.Trimesh(mesh_output.mesh_v, mesh_output.mesh_f)
30
- return mesh_output
31
-
32
- class VAETrainer(pl.LightningModule):
33
- def __init__(
34
- self,
35
- *,
36
- vae_config,
37
- optimizer_cfg,
38
- loss_cfg,
39
- save_dir,
40
- mc_res,
41
- ckpt_path: Optional[str] = None,
42
- ignore_keys: Union[Tuple[str], List[str]] = (),
43
- torch_compile: bool = False,
44
- ):
45
- super().__init__()
46
-
47
- # ========= init optimizer config ========= #
48
- self.optimizer_cfg = optimizer_cfg
49
- self.loss_cfg = loss_cfg
50
- self.ckpt_path = ckpt_path
51
- self.vae_model = instantiate_vae_model(vae_config, requires_grad=True)
52
- if ckpt_path is not None:
53
- self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
54
-
55
- self.mc_res = mc_res
56
- self.save_root = save_dir
57
- if not os.path.exists(save_dir):
58
- os.makedirs(save_dir)
59
-
60
- # ========= torch compile to accelerate ========= #
61
- self.torch_compile = torch_compile
62
- if self.torch_compile:
63
- torch.nn.Module.compile(self.vae_model)
64
- print(f'*' * 100)
65
- print(f'Compile model for acceleration')
66
- print(f'*' * 100)
67
-
68
- def init_from_ckpt(self, path, ignore_keys=()):
69
- ckpt = torch.load(path, map_location="cpu")
70
- if 'state_dict' not in ckpt:
71
- # deepspeed ckpt
72
- state_dict = {}
73
- for k in ckpt.keys():
74
- new_k = k.replace('_forward_module.', '')
75
- state_dict[new_k] = ckpt[k]
76
- else:
77
- state_dict = ckpt["state_dict"]
78
-
79
- keys = list(state_dict.keys())
80
- for k in keys:
81
- for ik in ignore_keys:
82
- if ik in k:
83
- print("Deleting key {} from state_dict.".format(k))
84
- del state_dict[k]
85
-
86
- # # ==================== Weight Surgery Start ====================
87
- # old_key_base = "vae_model.encoder.input_proj"
88
- # old_weight_key = f"{old_key_base}.weight"
89
- # old_bias_key = f"{old_key_base}.bias"
90
-
91
- # if old_weight_key in state_dict:
92
- # print(f"[*] Detected legacy '{old_key_base}' in checkpoint. Performing weight surgery...")
93
-
94
- # src_weight = state_dict[old_weight_key]
95
- # src_bias = state_dict[old_bias_key]
96
-
97
- # encoder = self.vae_model.encoder
98
- # fourier_dim = encoder.fourier_embedder.out_dim
99
-
100
- # # --- A. input_proj_kv ---
101
- # # shape: [width, fourier_dim + point_feats]
102
- # encoder.input_proj_kv.weight.data.copy_(src_weight)
103
- # encoder.input_proj_kv.bias.data.copy_(src_bias)
104
- # print(f" -> Loaded input_proj_kv from {old_key_base}")
105
-
106
- # # --- B. input_proj_q ---
107
- # # shape: [width, fourier_dim]
108
- # sliced_weight = src_weight[:, :fourier_dim]
109
- # encoder.input_proj_q.weight.data.copy_(sliced_weight)
110
- # encoder.input_proj_q.bias.data.copy_(src_bias)
111
- # print(f" -> Loaded input_proj_q (sliced) from {old_key_base}")
112
-
113
- # del state_dict[old_weight_key]
114
- # if old_bias_key in state_dict:
115
- # del state_dict[old_bias_key]
116
- # # ==================== Weight Surgery End ====================
117
-
118
- missing, unexpected = self.load_state_dict(state_dict, strict=False)
119
- print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
120
- if len(missing) > 0:
121
- print(f"Missing Keys: {missing}")
122
- print(f"Unexpected Keys: {unexpected}")
123
-
124
-
125
- def configure_optimizers(self) -> Tuple[List, List]:
126
- lr = self.learning_rate
127
-
128
- params_list = []
129
- trainable_parameters = list(self.vae_model.parameters())
130
- params_list.append({'params': trainable_parameters, 'lr': lr})
131
-
132
- optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=params_list, lr=lr)
133
- if hasattr(self.optimizer_cfg, 'scheduler'):
134
- scheduler_func = instantiate_from_config(
135
- self.optimizer_cfg.scheduler,
136
- max_decay_steps=self.trainer.max_steps,
137
- lr_max=lr
138
- )
139
- scheduler = {
140
- "scheduler": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule),
141
- "interval": "step",
142
- "frequency": 1
143
- }
144
- schedulers = [scheduler]
145
- else:
146
- schedulers = []
147
- optimizers = [optimizer]
148
-
149
- return optimizers, schedulers
150
-
151
- def on_train_epoch_start(self) -> None:
152
- pl.seed_everything(self.trainer.global_rank)
153
-
154
- def forward(self, batch):
155
- sup_pc_s_list = [batch["sup_near_uniform"], batch["sup_near_sharp"], batch["sup_space"]]
156
- rand_points = [sup_pc_s[:,:,:3] for sup_pc_s in sup_pc_s_list]
157
- rand_points_val = [sup_pc_s[:,:,3:] for sup_pc_s in sup_pc_s_list]
158
-
159
- rand_points = torch.cat(rand_points, dim=1)
160
- target = torch.cat(rand_points_val, dim=1)[...,0]
161
- target = -target
162
-
163
- latents, posterior = self.vae_model.encode(
164
- batch['surface'], sample_posterior=True, need_kl=True)
165
- latents = self.vae_model.decode(latents)
166
- logits = self.vae_model.query(latents, rand_points)
167
-
168
- loss_kl = posterior.kl()
169
- loss_kl = torch.sum(loss_kl) / loss_kl.shape[0]
170
-
171
- criteria = torch.nn.MSELoss()
172
- criteria2 = torch.nn.L1Loss()
173
- loss_logits = criteria(logits, target).mean() + criteria2(logits, target).mean()
174
- loss = self.loss_cfg.lambda_logits * loss_logits + self.loss_cfg.lambda_kl * loss_kl
175
-
176
- loss_dict = {
177
- "loss": loss,
178
- "loss_logits": loss_logits,
179
- "loss_kl": loss_kl
180
- }
181
- return loss_dict, latents
182
-
183
- def training_step(self, batch, batch_idx, optimizer_idx=0):
184
- loss, latents = self.forward(batch)
185
- split = 'train'
186
- loss_dict = {
187
- f"{split}/total_loss": loss["loss"].detach(),
188
- f"{split}/loss_logits": loss["loss_logits"].detach(),
189
- f"{split}/loss_kl": loss["loss_kl"].detach(),
190
- f"{split}/lr_abs": self.optimizers().param_groups[0]['lr'],
191
- }
192
- self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True)
193
-
194
- return loss
195
-
196
- def validation_step(self, batch, batch_idx, optimizer_idx=0):
197
- loss, latents = self.forward(batch)
198
- split = 'val'
199
- loss_dict = {
200
- f"{split}/total_loss": loss["loss"].detach(),
201
- f"{split}/loss_logits": loss["loss_logits"].detach(),
202
- f"{split}/loss_kl": loss["loss_kl"].detach(),
203
- f"{split}/lr_abs": self.optimizers().param_groups[0]['lr'],
204
- }
205
- self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True)
206
- if self.trainer.global_rank < 2:
207
- with torch.no_grad():
208
- save_dir = f"{self.save_root}/gs{self.global_step:010d}_rank{self.trainer.global_rank}"
209
- if not os.path.exists(save_dir):
210
- os.makedirs(save_dir)
211
- uids = batch.get('uid')
212
- for i, latent in enumerate(latents[:5]):
213
- mesh, grid_logits = self.vae_model.latents2mesh(
214
- latent[None],
215
- output_type='trimesh',
216
- bounds=1.01,
217
- mc_level=0.0,
218
- num_chunks=20000,
219
- octree_resolution=self.mc_res,
220
- mc_algo='mc',
221
- enable_pbar=True
222
- )
223
-
224
- mesh = export_to_trimesh(mesh[0])
225
-
226
- save_path = f"{save_dir}/recon_{os.path.splitext(os.path.basename(uids[i]))[0]}_mc{self.mc_res}.obj"
227
- mesh.export(save_path)
228
-
229
- return loss
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ultrashape/models/autoencoders/volume_decoders.py DELETED
@@ -1,440 +0,0 @@
1
- # ==============================================================================
2
- # Original work Copyright (c) 2025 Tencent.
3
- # Modified work Copyright (c) 2025 UltraShape Team.
4
- #
5
- # Modified by UltraShape on 2025.12.25
6
- # ==============================================================================
7
-
8
- # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
9
- # except for the third-party components listed below.
10
- # Hunyuan 3D does not impose any additional limitations beyond what is outlined
11
- # in the repsective licenses of these third-party components.
12
- # Users must comply with all terms and conditions of original licenses of these third-party
13
- # components and must ensure that the usage of the third party components adheres to
14
- # all relevant laws and regulations.
15
-
16
- # For avoidance of doubts, Hunyuan 3D means the large language models and
17
- # their software and algorithms, including trained model weights, parameters (including
18
- # optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
19
- # fine-tuning enabling code and other elements of the foregoing made publicly available
20
- # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
21
-
22
- from typing import Union, Tuple, List, Callable
23
-
24
- import numpy as np
25
- import torch
26
- import torch.nn as nn
27
- import torch.nn.functional as F
28
- from einops import repeat
29
- from tqdm import tqdm
30
-
31
- from .attention_blocks import CrossAttentionDecoder
32
- from .attention_processors import FlashVDMCrossAttentionProcessor, FlashVDMTopMCrossAttentionProcessor
33
- from ...utils import logger
34
-
35
-
36
- def extract_near_surface_volume_fn(input_tensor: torch.Tensor, alpha: float):
37
- val = input_tensor + alpha
38
- valid_mask = val > -9000
39
-
40
- mask = torch.ones_like(val, dtype=torch.int32)
41
- sign = torch.sign(val.to(torch.float32))
42
-
43
- # Helper to compute neighbor for a single direction
44
- def check_neighbor_sign(shift, axis):
45
- if shift == 0:
46
- return
47
-
48
- pad_dims = [0, 0, 0, 0, 0, 0]
49
- if axis == 0:
50
- pad_idx = 0 if shift > 0 else 1
51
- pad_dims[pad_idx] = abs(shift)
52
- elif axis == 1:
53
- pad_idx = 2 if shift > 0 else 3
54
- pad_dims[pad_idx] = abs(shift)
55
- elif axis == 2:
56
- pad_idx = 4 if shift > 0 else 5
57
- pad_dims[pad_idx] = abs(shift)
58
-
59
- padded = F.pad(val.unsqueeze(0).unsqueeze(0), pad_dims[::-1], mode='replicate')
60
-
61
- slice_dims = [slice(None)] * 3
62
- if axis == 0:
63
- if shift > 0: slice_dims[0] = slice(shift, None)
64
- else: slice_dims[0] = slice(None, shift)
65
- elif axis == 1:
66
- if shift > 0: slice_dims[1] = slice(shift, None)
67
- else: slice_dims[1] = slice(None, shift)
68
- elif axis == 2:
69
- if shift > 0: slice_dims[2] = slice(shift, None)
70
- else: slice_dims[2] = slice(None, shift)
71
-
72
- padded = padded.squeeze(0).squeeze(0)
73
- neighbor = padded[slice_dims]
74
- neighbor = torch.where(neighbor > -9000, neighbor, val)
75
-
76
- # Check sign consistency
77
- neighbor_sign = torch.sign(neighbor.to(torch.float32))
78
- return (neighbor_sign == sign)
79
-
80
- # Iteratively check neighbors and update mask
81
- # directions: (shift, axis)
82
- directions = [(1, 0), (-1, 0), (1, 1), (-1, 1), (1, 2), (-1, 2)]
83
-
84
- for shift, axis in directions:
85
- is_same = check_neighbor_sign(shift, axis)
86
- mask = mask & is_same.to(torch.int32)
87
-
88
- # Invert mask: we want 1 where ANY neighbor has different sign
89
- mask = (~(mask.bool())).to(torch.int32)
90
- return mask * valid_mask.to(torch.int32)
91
-
92
-
93
- def generate_dense_grid_points(
94
- bbox_min: np.ndarray,
95
- bbox_max: np.ndarray,
96
- octree_resolution: int,
97
- indexing: str = "ij",
98
- ):
99
- length = bbox_max - bbox_min
100
- num_cells = octree_resolution
101
-
102
- x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32)
103
- y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32)
104
- z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32)
105
- [xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing)
106
- xyz = np.stack((xs, ys, zs), axis=-1)
107
- grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1]
108
-
109
- return xyz, grid_size, length
110
-
111
-
112
- class VanillaVolumeDecoder:
113
- @torch.no_grad()
114
- def __call__(
115
- self,
116
- latents: torch.FloatTensor,
117
- geo_decoder: Callable,
118
- bounds: Union[Tuple[float], List[float], float] = 1.01,
119
- num_chunks: int = 10000,
120
- octree_resolution: int = None,
121
- enable_pbar: bool = True,
122
- **kwargs,
123
- ):
124
- device = latents.device
125
- dtype = latents.dtype
126
- batch_size = latents.shape[0]
127
-
128
- # 1. generate query points
129
- if isinstance(bounds, float):
130
- bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
131
-
132
- bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6])
133
- xyz_samples, grid_size, length = generate_dense_grid_points(
134
- bbox_min=bbox_min,
135
- bbox_max=bbox_max,
136
- octree_resolution=octree_resolution,
137
- indexing="ij"
138
- )
139
- xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype).contiguous().reshape(-1, 3)
140
-
141
- # 2. latents to 3d volume
142
- batch_logits = []
143
- for start in tqdm(range(0, xyz_samples.shape[0], num_chunks), desc=f"Volume Decoding",
144
- disable=not enable_pbar):
145
- chunk_queries = xyz_samples[start: start + num_chunks, :]
146
- chunk_queries = repeat(chunk_queries, "p c -> b p c", b=batch_size)
147
- logits = geo_decoder(queries=chunk_queries, latents=latents)
148
- batch_logits.append(logits)
149
-
150
- grid_logits = torch.cat(batch_logits, dim=1)
151
- grid_logits = grid_logits.view((batch_size, *grid_size)).float()
152
-
153
- return grid_logits
154
-
155
-
156
- class HierarchicalVolumeDecoding:
157
- @torch.no_grad()
158
- def __call__(
159
- self,
160
- latents: torch.FloatTensor,
161
- geo_decoder: Callable,
162
- bounds: Union[Tuple[float], List[float], float] = 1.01,
163
- num_chunks: int = 10000,
164
- mc_level: float = 0.0,
165
- octree_resolution: int = None,
166
- min_resolution: int = 63,
167
- enable_pbar: bool = True,
168
- **kwargs,
169
- ):
170
- device = latents.device
171
- dtype = latents.dtype
172
-
173
- resolutions = []
174
- if octree_resolution < min_resolution:
175
- resolutions.append(octree_resolution)
176
- while octree_resolution >= min_resolution:
177
- resolutions.append(octree_resolution)
178
- octree_resolution = octree_resolution // 2
179
- resolutions.reverse()
180
-
181
- # 1. generate query points
182
- if isinstance(bounds, float):
183
- bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
184
- bbox_min = np.array(bounds[0:3])
185
- bbox_max = np.array(bounds[3:6])
186
- bbox_size = bbox_max - bbox_min
187
-
188
- xyz_samples, grid_size, length = generate_dense_grid_points(
189
- bbox_min=bbox_min,
190
- bbox_max=bbox_max,
191
- octree_resolution=resolutions[0],
192
- indexing="ij"
193
- )
194
-
195
- dilate = nn.Conv3d(1, 1, 3, padding=1, bias=False, device=device, dtype=dtype)
196
- dilate.weight = torch.nn.Parameter(torch.ones(dilate.weight.shape, dtype=dtype, device=device))
197
-
198
- grid_size = np.array(grid_size)
199
- xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype).contiguous().reshape(-1, 3)
200
-
201
- # 2. latents to 3d volume
202
- batch_logits = []
203
- batch_size = latents.shape[0]
204
- for start in tqdm(range(0, xyz_samples.shape[0], num_chunks),
205
- desc=f"Hierarchical Volume Decoding [r{resolutions[0] + 1}]"):
206
- queries = xyz_samples[start: start + num_chunks, :]
207
- batch_queries = repeat(queries, "p c -> b p c", b=batch_size)
208
- logits = geo_decoder(queries=batch_queries, latents=latents)
209
- batch_logits.append(logits)
210
-
211
- grid_logits = torch.cat(batch_logits, dim=1).view((batch_size, grid_size[0], grid_size[1], grid_size[2]))
212
-
213
- for octree_depth_now in resolutions[1:]:
214
- grid_size = np.array([octree_depth_now + 1] * 3)
215
- resolution = bbox_size / octree_depth_now
216
- next_index = torch.zeros(tuple(grid_size), dtype=dtype, device=device)
217
- next_logits = torch.full(next_index.shape, -10000., dtype=dtype, device=device)
218
- curr_points = extract_near_surface_volume_fn(grid_logits.squeeze(0), mc_level)
219
- curr_points += grid_logits.squeeze(0).abs() < 0.95
220
-
221
- if octree_depth_now == resolutions[-1]:
222
- expand_num = 0
223
- else:
224
- expand_num = 1
225
- for i in range(expand_num):
226
- curr_points = dilate(curr_points.unsqueeze(0).to(dtype)).squeeze(0)
227
- (cidx_x, cidx_y, cidx_z) = torch.where(curr_points > 0)
228
- next_index[cidx_x * 2, cidx_y * 2, cidx_z * 2] = 1
229
- for i in range(2 - expand_num):
230
- next_index = dilate(next_index.unsqueeze(0)).squeeze(0)
231
- nidx = torch.where(next_index > 0)
232
-
233
- # Store shape before deleting
234
- next_index_shape = next_index.shape
235
- del next_index
236
- torch.cuda.empty_cache()
237
-
238
- next_points = torch.stack(nidx, dim=1)
239
- next_points = (next_points * torch.tensor(resolution, dtype=next_points.dtype, device=device) +
240
- torch.tensor(bbox_min, dtype=next_points.dtype, device=device))
241
- batch_logits = []
242
- for start in tqdm(range(0, next_points.shape[0], num_chunks),
243
- desc=f"Hierarchical Volume Decoding [r{octree_depth_now + 1}]"):
244
- queries = next_points[start: start + num_chunks, :]
245
- batch_queries = repeat(queries, "p c -> b p c", b=batch_size)
246
- logits = geo_decoder(queries=batch_queries.to(latents.dtype), latents=latents)
247
- batch_logits.append(logits)
248
-
249
- # Delayed allocation of next_logits
250
- next_logits = torch.full(next_index_shape, -10000., dtype=dtype, device=device)
251
- grid_logits = torch.cat(batch_logits, dim=1)
252
- next_logits[nidx] = grid_logits[0, ..., 0]
253
- grid_logits = next_logits.unsqueeze(0)
254
- grid_logits[grid_logits == -10000.] = float('nan')
255
-
256
- return grid_logits
257
-
258
-
259
- class FlashVDMVolumeDecoding:
260
- def __init__(self, topk_mode='mean'):
261
- if topk_mode not in ['mean', 'merge']:
262
- raise ValueError(f'Unsupported topk_mode {topk_mode}, available: {["mean", "merge"]}')
263
-
264
- if topk_mode == 'mean':
265
- self.processor = FlashVDMCrossAttentionProcessor()
266
- else:
267
- self.processor = FlashVDMTopMCrossAttentionProcessor()
268
-
269
- @torch.no_grad()
270
- def __call__(
271
- self,
272
- latents: torch.FloatTensor,
273
- geo_decoder: CrossAttentionDecoder,
274
- bounds: Union[Tuple[float], List[float], float] = 1.01,
275
- num_chunks: int = 10000,
276
- mc_level: float = 0.0,
277
- octree_resolution: int = None,
278
- min_resolution: int = 63,
279
- mini_grid_num: int = 4,
280
- enable_pbar: bool = True,
281
- **kwargs,
282
- ):
283
- processor = self.processor
284
- geo_decoder.set_cross_attention_processor(processor)
285
-
286
- device = latents.device
287
- dtype = latents.dtype
288
-
289
- resolutions = []
290
- if octree_resolution < min_resolution:
291
- resolutions.append(octree_resolution)
292
- while octree_resolution >= min_resolution:
293
- resolutions.append(octree_resolution)
294
- octree_resolution = octree_resolution // 2
295
- resolutions.reverse()
296
- resolutions[0] = round(resolutions[0] / mini_grid_num) * mini_grid_num - 1
297
- for i, resolution in enumerate(resolutions[1:]):
298
- resolutions[i + 1] = resolutions[0] * 2 ** (i + 1)
299
-
300
- logger.info(f"FlashVDMVolumeDecoding Resolution: {resolutions}")
301
-
302
- # 1. generate query points
303
- if isinstance(bounds, float):
304
- bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
305
- bbox_min = np.array(bounds[0:3])
306
- bbox_max = np.array(bounds[3:6])
307
- bbox_size = bbox_max - bbox_min
308
-
309
- xyz_samples, grid_size, length = generate_dense_grid_points(
310
- bbox_min=bbox_min,
311
- bbox_max=bbox_max,
312
- octree_resolution=resolutions[0],
313
- indexing="ij"
314
- )
315
-
316
- dilate = nn.Conv3d(1, 1, 3, padding=1, bias=False, device=device, dtype=dtype)
317
- dilate.weight = torch.nn.Parameter(torch.ones(dilate.weight.shape, dtype=dtype, device=device))
318
-
319
- grid_size = np.array(grid_size)
320
-
321
- # 2. latents to 3d volume
322
- xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype)
323
- batch_size = latents.shape[0]
324
- mini_grid_size = xyz_samples.shape[0] // mini_grid_num
325
- xyz_samples = xyz_samples.view(
326
- mini_grid_num, mini_grid_size,
327
- mini_grid_num, mini_grid_size,
328
- mini_grid_num, mini_grid_size, 3
329
- ).permute(
330
- 0, 2, 4, 1, 3, 5, 6
331
- ).reshape(
332
- -1, mini_grid_size * mini_grid_size * mini_grid_size, 3
333
- )
334
- batch_logits = []
335
- num_batchs = max(num_chunks // xyz_samples.shape[1], 1)
336
- for start in tqdm(range(0, xyz_samples.shape[0], num_batchs),
337
- desc=f"FlashVDM Volume Decoding", disable=not enable_pbar):
338
- queries = xyz_samples[start: start + num_batchs, :]
339
- batch = queries.shape[0]
340
- batch_latents = repeat(latents.squeeze(0), "p c -> b p c", b=batch)
341
- processor.topk = True
342
-
343
- # Chunk queries along dim 1 if too large
344
- if queries.shape[1] > num_chunks:
345
- batch_logits_sub = []
346
- for sub_start in range(0, queries.shape[1], num_chunks):
347
- sub_queries = queries[:, sub_start: sub_start + num_chunks, :]
348
- logits = geo_decoder(queries=sub_queries, latents=batch_latents)
349
- batch_logits_sub.append(logits)
350
- logits = torch.cat(batch_logits_sub, dim=1)
351
- else:
352
- logits = geo_decoder(queries=queries, latents=batch_latents)
353
-
354
- batch_logits.append(logits)
355
- grid_logits = torch.cat(batch_logits, dim=0).reshape(
356
- mini_grid_num, mini_grid_num, mini_grid_num,
357
- mini_grid_size, mini_grid_size,
358
- mini_grid_size
359
- ).permute(0, 3, 1, 4, 2, 5).contiguous().view(
360
- (batch_size, grid_size[0], grid_size[1], grid_size[2])
361
- )
362
-
363
- for octree_depth_now in resolutions[1:]:
364
- grid_size = np.array([octree_depth_now + 1] * 3)
365
- resolution = bbox_size / octree_depth_now
366
- next_index = torch.zeros(tuple(grid_size), dtype=dtype, device=device)
367
- curr_points = extract_near_surface_volume_fn(grid_logits.squeeze(0), mc_level)
368
- curr_points += grid_logits.squeeze(0).abs() < 0.95
369
-
370
- if octree_depth_now == resolutions[-1]:
371
- expand_num = 0
372
- else:
373
- expand_num = 1
374
- for i in range(expand_num):
375
- curr_points = dilate(curr_points.unsqueeze(0).to(dtype)).squeeze(0)
376
- curr_points = dilate(curr_points.unsqueeze(0).to(dtype)).squeeze(0)
377
- (cidx_x, cidx_y, cidx_z) = torch.where(curr_points > 0)
378
-
379
- next_index[cidx_x * 2, cidx_y * 2, cidx_z * 2] = 1
380
- for i in range(2 - expand_num):
381
- next_index = dilate(next_index.unsqueeze(0)).squeeze(0)
382
- nidx = torch.where(next_index > 0)
383
-
384
- # Store shape before deleting
385
- next_index_shape = next_index.shape
386
- del next_index
387
- torch.cuda.empty_cache()
388
-
389
- next_points = torch.stack(nidx, dim=1)
390
- next_points = (next_points * torch.tensor(resolution, dtype=torch.float32, device=device) +
391
- torch.tensor(bbox_min, dtype=torch.float32, device=device))
392
-
393
- query_grid_num = 6
394
- min_val = next_points.min(axis=0).values
395
- max_val = next_points.max(axis=0).values
396
- vol_queries_index = (next_points - min_val) / (max_val - min_val) * (query_grid_num - 0.001)
397
- index = torch.floor(vol_queries_index).long()
398
- index = index[..., 0] * (query_grid_num ** 2) + index[..., 1] * query_grid_num + index[..., 2]
399
- index = index.sort()
400
- next_points = next_points[index.indices].unsqueeze(0).contiguous()
401
- unique_values = torch.unique(index.values, return_counts=True)
402
- grid_logits = torch.zeros((next_points.shape[1]), dtype=latents.dtype, device=latents.device)
403
- input_grid = [[], []]
404
- logits_grid_list = []
405
- start_num = 0
406
- sum_num = 0
407
- for grid_index, count in zip(unique_values[0].cpu().tolist(), unique_values[1].cpu().tolist()):
408
- remaining_count = count
409
- while remaining_count > 0:
410
- space_left = num_chunks - sum_num
411
- # If buffer is full, flush it
412
- if space_left <= 0:
413
- processor.topk = input_grid
414
- logits_grid = geo_decoder(queries=next_points[:, start_num:start_num + sum_num], latents=latents)
415
- start_num = start_num + sum_num
416
- logits_grid_list.append(logits_grid)
417
- input_grid = [[], []]
418
- sum_num = 0
419
- space_left = num_chunks
420
-
421
- take = min(remaining_count, space_left)
422
- input_grid[0].append(grid_index)
423
- input_grid[1].append(take)
424
- sum_num += take
425
- remaining_count -= take
426
- if sum_num > 0:
427
- processor.topk = input_grid
428
- logits_grid = geo_decoder(queries=next_points[:, start_num:start_num + sum_num], latents=latents)
429
- logits_grid_list.append(logits_grid)
430
- logits_grid = torch.cat(logits_grid_list, dim=1)
431
- grid_logits[index.indices] = logits_grid.squeeze(0).squeeze(-1)
432
-
433
- # Delayed allocation of next_logits
434
- next_logits = torch.full(next_index_shape, -10000., dtype=dtype, device=device)
435
- next_logits[nidx] = grid_logits
436
- grid_logits = next_logits.unsqueeze(0)
437
-
438
- grid_logits[grid_logits == -10000.] = float('nan')
439
-
440
- return grid_logits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ultrashape/models/conditioner_mask.py DELETED
@@ -1,337 +0,0 @@
1
- # ==============================================================================
2
- # Original work Copyright (c) 2025 Tencent.
3
- # Modified work Copyright (c) 2025 UltraShape Team.
4
- #
5
- # Modified by UltraShape on 2025.12.25
6
- # ==============================================================================
7
-
8
- # Open Source Model Licensed under the Apache License Version 2.0
9
- # and Other Licenses of the Third-Party Components therein:
10
- # The below Model in this distribution may have been modified by THL A29 Limited
11
- # ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
12
-
13
- # Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
14
- # The below software and/or models in this distribution may have been
15
- # modified by THL A29 Limited ("Tencent Modifications").
16
- # All Tencent Modifications are Copyright (C) THL A29 Limited.
17
-
18
- # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
19
- # except for the third-party components listed below.
20
- # Hunyuan 3D does not impose any additional limitations beyond what is outlined
21
- # in the repsective licenses of these third-party components.
22
- # Users must comply with all terms and conditions of original licenses of these third-party
23
- # components and must ensure that the usage of the third party components adheres to
24
- # all relevant laws and regulations.
25
-
26
- # For avoidance of doubts, Hunyuan 3D means the large language models and
27
- # their software and algorithms, including trained model weights, parameters (including
28
- # optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
29
- # fine-tuning enabling code and other elements of the foregoing made publicly available
30
- # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
31
-
32
-
33
- import numpy as np
34
- import torch
35
- import torch.nn as nn
36
- from torchvision import transforms
37
- from transformers import (
38
- CLIPVisionModelWithProjection,
39
- CLIPVisionConfig,
40
- Dinov2Model,
41
- Dinov2Config,
42
- )
43
- from transformers import AutoImageProcessor, AutoModel
44
-
45
- def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
46
- """
47
- embed_dim: output dimension for each position
48
- pos: a list of positions to be encoded: size (M,)
49
- out: (M, D)
50
- """
51
- assert embed_dim % 2 == 0
52
- omega = np.arange(embed_dim // 2, dtype=np.float64)
53
- omega /= embed_dim / 2.
54
- omega = 1. / 10000 ** omega # (D/2,)
55
-
56
- pos = pos.reshape(-1) # (M,)
57
- out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
58
-
59
- emb_sin = np.sin(out) # (M, D/2)
60
- emb_cos = np.cos(out) # (M, D/2)
61
-
62
- return np.concatenate([emb_sin, emb_cos], axis=1)
63
-
64
-
65
- class ImageEncoder(nn.Module):
66
- def __init__(
67
- self,
68
- version=None,
69
- config=None,
70
- use_cls_token=True,
71
- image_size=224,
72
- **kwargs,
73
- ):
74
- super().__init__()
75
-
76
- if config is None:
77
- self.model = AutoModel.from_pretrained(version)
78
- else:
79
- self.model = self.MODEL_CLASS(self.MODEL_CONFIG_CLASS.from_dict(config))
80
-
81
- self.model.eval()
82
- self.model.requires_grad_(False)
83
- self.use_cls_token = use_cls_token
84
- self.size = image_size // 14
85
- self.num_patches = (image_size // 14) ** 2
86
- if self.use_cls_token:
87
- self.num_patches += 1
88
-
89
- self.transform = transforms.Compose(
90
- [
91
- transforms.Resize(image_size, transforms.InterpolationMode.BILINEAR, antialias=True),
92
- transforms.CenterCrop(image_size),
93
- transforms.Normalize(
94
- mean=self.mean,
95
- std=self.std,
96
- ),
97
- ]
98
- )
99
-
100
- self.mask_transform = transforms.Compose(
101
- [
102
- transforms.Resize(image_size, interpolation=transforms.InterpolationMode.NEAREST),
103
- transforms.CenterCrop(image_size),
104
- ]
105
- )
106
-
107
- def forward(self, image, mask=None, value_range=(-1, 1), **kwargs):
108
- if value_range is not None:
109
- low, high = value_range
110
- image = (image - low) / (high - low)
111
-
112
- image = image.to(self.model.device, dtype=self.model.dtype)
113
- inputs = self.transform(image)
114
- outputs = self.model(inputs)
115
-
116
- last_hidden_state = outputs.last_hidden_state
117
- if not self.use_cls_token:
118
- last_hidden_state = last_hidden_state[:, 1:, :]
119
-
120
- if mask is not None:
121
- pool = nn.MaxPool2d(kernel_size=(14, 14), stride=(14, 14))
122
-
123
- mask = self.mask_transform(mask)
124
- mask = mask.to(image.device, dtype=image.dtype)
125
- downsampled_mask = pool(mask)
126
- flattened_mask = downsampled_mask.view(downsampled_mask.shape[0], -1)
127
- flattened_mask = flattened_mask.unsqueeze(-1)
128
-
129
- if self.use_cls_token:
130
- flattened_mask = torch.cat(
131
- [torch.ones(flattened_mask.shape[0], 1, 1, device=flattened_mask.device, dtype=flattened_mask.dtype),
132
- flattened_mask], dim=1)
133
-
134
- valid_mask = (flattened_mask != -1).float()
135
- masked_hidden_state = last_hidden_state * valid_mask
136
- valid_mask_bool = valid_mask.squeeze(-1) > 0
137
-
138
- valid_counts = valid_mask_bool.sum(dim=1)
139
- max_valid_tokens = valid_counts.max().item()
140
-
141
- batch_indices = torch.arange(valid_mask_bool.shape[0], device=valid_mask_bool.device)
142
- batch_indices = batch_indices.unsqueeze(1).expand(-1, valid_mask_bool.shape[1])
143
-
144
- flat_batch_indices = batch_indices[valid_mask_bool]
145
- flat_token_indices = torch.arange(valid_mask_bool.shape[1], device=valid_mask_bool.device)
146
- flat_token_indices = flat_token_indices.unsqueeze(0).expand(valid_mask_bool.shape[0], -1)
147
- flat_token_indices = flat_token_indices[valid_mask_bool]
148
-
149
- valid_tokens = masked_hidden_state[flat_batch_indices, flat_token_indices]
150
- # Create output tensor with special padding value (-1) instead of zeros
151
- final_output = torch.full(
152
- (valid_mask_bool.shape[0], max_valid_tokens, last_hidden_state.shape[-1]),
153
- -1.0, # Use -1 as padding value to clearly distinguish from valid tokens
154
- device=last_hidden_state.device, dtype=last_hidden_state.dtype
155
- )
156
-
157
- cum_counts = torch.cumsum(valid_counts, dim=0) - valid_counts
158
- for i in range(valid_mask_bool.shape[0]):
159
- if valid_counts[i] > 0:
160
- start_idx = cum_counts[i]
161
- end_idx = start_idx + valid_counts[i]
162
- final_output[i, :valid_counts[i]] = valid_tokens[start_idx:end_idx]
163
-
164
- return final_output
165
-
166
- return last_hidden_state
167
-
168
- def unconditional_embedding(self, batch_size, **kwargs):
169
- device = next(self.model.parameters()).device
170
- dtype = next(self.model.parameters()).dtype
171
-
172
- num_tokens = kwargs.get('num_tokens', self.num_patches)
173
-
174
- zero = torch.zeros(
175
- batch_size,
176
- num_tokens,
177
- self.model.config.hidden_size,
178
- device=device,
179
- dtype=dtype,
180
- )
181
-
182
- return zero
183
-
184
-
185
- class CLIPImageEncoder(ImageEncoder):
186
- MODEL_CLASS = CLIPVisionModelWithProjection
187
- MODEL_CONFIG_CLASS = CLIPVisionConfig
188
- mean = [0.48145466, 0.4578275, 0.40821073]
189
- std = [0.26862954, 0.26130258, 0.27577711]
190
-
191
-
192
- class DinoImageEncoder(ImageEncoder):
193
- MODEL_CLASS = Dinov2Model
194
- MODEL_CONFIG_CLASS = Dinov2Config
195
- mean = [0.485, 0.456, 0.406]
196
- std = [0.229, 0.224, 0.225]
197
-
198
- class DinoImageEncoderMV(DinoImageEncoder):
199
- def __init__(
200
- self,
201
- version=None,
202
- config=None,
203
- use_cls_token=True,
204
- image_size=224,
205
- view_num=4,
206
- **kwargs,
207
- ):
208
- super().__init__(version, config, use_cls_token, image_size, **kwargs)
209
- self.view_num = view_num
210
- self.num_patches = self.num_patches
211
- pos = np.arange(self.view_num, dtype=np.float32)
212
- view_embedding = torch.from_numpy(
213
- get_1d_sincos_pos_embed_from_grid(self.model.config.hidden_size, pos)).float()
214
-
215
- view_embedding = view_embedding.unsqueeze(1).repeat(1, self.num_patches, 1)
216
- self.view_embed = view_embedding.unsqueeze(0)
217
-
218
- def forward(self, image, mask=None, value_range=(-1, 1), view_idxs=None):
219
- if value_range is not None:
220
- low, high = value_range
221
- image = (image - low) / (high - low)
222
-
223
- image = image.to(self.model.device, dtype=self.model.dtype)
224
-
225
- bs, num_views, c, h, w = image.shape
226
- image = image.view(bs * num_views, c, h, w)
227
-
228
- inputs = self.transform(image)
229
- outputs = self.model(inputs)
230
-
231
- last_hidden_state = outputs.last_hidden_state
232
- last_hidden_state = last_hidden_state.view(
233
- bs, num_views, last_hidden_state.shape[-2],
234
- last_hidden_state.shape[-1]
235
- )
236
-
237
- view_embedding = self.view_embed.to(last_hidden_state.dtype).to(last_hidden_state.device)
238
- if view_idxs is not None:
239
- assert len(view_idxs) == bs
240
- view_embeddings = []
241
- for i in range(bs):
242
- view_idx = view_idxs[i]
243
- assert num_views == len(view_idx)
244
- view_embeddings.append(self.view_embed[:, view_idx, ...])
245
- view_embedding = torch.cat(view_embeddings, 0).to(last_hidden_state.dtype).to(last_hidden_state.device)
246
-
247
- if num_views != self.view_num:
248
- view_embedding = view_embedding[:, :num_views, ...]
249
- last_hidden_state = last_hidden_state + view_embedding
250
- last_hidden_state = last_hidden_state.view(bs, num_views * last_hidden_state.shape[-2],
251
- last_hidden_state.shape[-1])
252
- return last_hidden_state
253
-
254
- def unconditional_embedding(self, batch_size, view_idxs=None, **kwargs):
255
- device = next(self.model.parameters()).device
256
- dtype = next(self.model.parameters()).dtype
257
- zero = torch.zeros(
258
- batch_size,
259
- self.num_patches * len(view_idxs[0]),
260
- self.model.config.hidden_size,
261
- device=device,
262
- dtype=dtype,
263
- )
264
- return zero
265
-
266
-
267
- def build_image_encoder(config):
268
- if config['type'] == 'CLIPImageEncoder':
269
- return CLIPImageEncoder(**config['kwargs'])
270
- elif config['type'] == 'DinoImageEncoder':
271
- return DinoImageEncoder(**config['kwargs'])
272
- elif config['type'] == 'DinoImageEncoderMV':
273
- return DinoImageEncoderMV(**config['kwargs'])
274
- else:
275
- raise ValueError(f'Unknown image encoder type: {config["type"]}')
276
-
277
-
278
- class DualImageEncoder(nn.Module):
279
- def __init__(
280
- self,
281
- main_image_encoder,
282
- additional_image_encoder,
283
- ):
284
- super().__init__()
285
- self.main_image_encoder = build_image_encoder(main_image_encoder)
286
- self.additional_image_encoder = build_image_encoder(additional_image_encoder)
287
-
288
- def forward(self, image, mask=None, **kwargs):
289
- outputs = {
290
- 'main': self.main_image_encoder(image, mask=mask, **kwargs),
291
- 'additional': self.additional_image_encoder(image, mask=mask, **kwargs),
292
- }
293
- return outputs
294
-
295
- def unconditional_embedding(self, batch_size, **kwargs):
296
- outputs = {
297
- 'main': self.main_image_encoder.unconditional_embedding(batch_size, **kwargs),
298
- 'additional': self.additional_image_encoder.unconditional_embedding(batch_size, **kwargs),
299
- }
300
- return outputs
301
-
302
-
303
- class SingleImageEncoder(nn.Module):
304
- def __init__(
305
- self,
306
- main_image_encoder,
307
- drop_ratio=0.1,
308
- ):
309
- super().__init__()
310
- self.main_image_encoder = build_image_encoder(main_image_encoder)
311
- self.drop_ratio = drop_ratio
312
- # self.disable_drop = disable_drop
313
-
314
- def forward(self, image, disable_drop=True, mask=None, **kwargs):
315
- outputs = {
316
- 'main': self.main_image_encoder(image, mask=mask, **kwargs),
317
- }
318
-
319
- if disable_drop:
320
- return outputs
321
- else:
322
- random_p = torch.rand(len(image), device='cuda')
323
- remain_bool_tensor = random_p > self.drop_ratio
324
- outputs['main'] *= remain_bool_tensor.view(-1,1,1)
325
- return outputs
326
-
327
-
328
- outputs = {
329
- 'main': self.main_image_encoder(image, mask=mask, **kwargs),
330
- }
331
- return outputs
332
-
333
- def unconditional_embedding(self, batch_size, **kwargs):
334
- outputs = {
335
- 'main': self.main_image_encoder.unconditional_embedding(batch_size, **kwargs),
336
- }
337
- return outputs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ultrashape/models/denoisers/__init__.py DELETED
@@ -1,22 +0,0 @@
1
- # ==============================================================================
2
- # Original work Copyright (c) 2025 Tencent.
3
- # Modified work Copyright (c) 2025 UltraShape Team.
4
- #
5
- # Modified by UltraShape on 2025.12.25
6
- # ==============================================================================
7
-
8
- # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
9
- # except for the third-party components listed below.
10
- # Hunyuan 3D does not impose any additional limitations beyond what is outlined
11
- # in the repsective licenses of these third-party components.
12
- # Users must comply with all terms and conditions of original licenses of these third-party
13
- # components and must ensure that the usage of the third party components adheres to
14
- # all relevant laws and regulations.
15
-
16
- # For avoidance of doubts, Hunyuan 3D means the large language models and
17
- # their software and algorithms, including trained model weights, parameters (including
18
- # optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
19
- # fine-tuning enabling code and other elements of the foregoing made publicly available
20
- # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
21
-
22
- from .dit_mask import RefineDiT
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ultrashape/models/denoisers/dit_mask.py DELETED
@@ -1,725 +0,0 @@
1
- # ==============================================================================
2
- # Original work Copyright (c) 2025 Tencent.
3
- # Modified work Copyright (c) 2025 UltraShape Team.
4
- #
5
- # Modified by UltraShape on 2025.12.25
6
- # ==============================================================================
7
-
8
- # Open Source Model Licensed under the Apache License Version 2.0
9
- # and Other Licenses of the Third-Party Components therein:
10
- # The below Model in this distribution may have been modified by THL A29 Limited
11
- # ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
12
-
13
- # Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
14
- # The below software and/or models in this distribution may have been
15
- # modified by THL A29 Limited ("Tencent Modifications").
16
- # All Tencent Modifications are Copyright (C) THL A29 Limited.
17
-
18
- # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
19
- # except for the third-party components listed below.
20
- # Hunyuan 3D does not impose any additional limitations beyond what is outlined
21
- # in the repsective licenses of these third-party components.
22
- # Users must comply with all terms and conditions of original licenses of these third-party
23
- # components and must ensure that the usage of the third party components adheres to
24
- # all relevant laws and regulations.
25
-
26
- # For avoidance of doubts, Hunyuan 3D means the large language models and
27
- # their software and algorithms, including trained model weights, parameters (including
28
- # optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
29
- # fine-tuning enabling code and other elements of the foregoing made publicly available
30
- # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
31
-
32
- import os
33
- import yaml
34
- import math
35
-
36
- import numpy as np
37
- import torch
38
- import torch.nn as nn
39
- import torch.nn.functional as F
40
- from einops import rearrange
41
-
42
- from .moe_layers import MoEBlock
43
- from ...utils import logger, synchronize_timer, smart_load_model
44
-
45
- from flash_attn import flash_attn_varlen_func
46
-
47
-
48
- def modulate(x, shift, scale):
49
- return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
50
-
51
-
52
- class Timesteps(nn.Module):
53
- def __init__(self,
54
- num_channels: int,
55
- downscale_freq_shift: float = 0.0,
56
- scale: int = 1,
57
- max_period: int = 10000
58
- ):
59
- super().__init__()
60
- self.num_channels = num_channels
61
- self.downscale_freq_shift = downscale_freq_shift
62
- self.scale = scale
63
- self.max_period = max_period
64
-
65
- def forward(self, timesteps):
66
- assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
67
- embedding_dim = self.num_channels
68
- half_dim = embedding_dim // 2
69
- exponent = -math.log(self.max_period) * torch.arange(
70
- start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
71
- exponent = exponent / (half_dim - self.downscale_freq_shift)
72
- emb = torch.exp(exponent)
73
- emb = timesteps[:, None].float() * emb[None, :]
74
- emb = self.scale * emb
75
- emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
76
- if embedding_dim % 2 == 1:
77
- emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
78
- return emb
79
-
80
-
81
- class TimestepEmbedder(nn.Module):
82
- """
83
- Embeds scalar timesteps into vector representations.
84
- """
85
-
86
- def __init__(self, hidden_size, frequency_embedding_size=256, cond_proj_dim=None, out_size=None):
87
- super().__init__()
88
- if out_size is None:
89
- out_size = hidden_size
90
- self.mlp = nn.Sequential(
91
- nn.Linear(hidden_size, frequency_embedding_size, bias=True),
92
- nn.GELU(),
93
- nn.Linear(frequency_embedding_size, out_size, bias=True),
94
- )
95
- self.frequency_embedding_size = frequency_embedding_size
96
-
97
- if cond_proj_dim is not None:
98
- self.cond_proj = nn.Linear(cond_proj_dim, frequency_embedding_size, bias=False)
99
-
100
- self.time_embed = Timesteps(hidden_size)
101
-
102
- def forward(self, t, condition):
103
-
104
- t_freq = self.time_embed(t).type(self.mlp[0].weight.dtype)
105
-
106
- # t_freq = timestep_embedding(t, self.frequency_embedding_size).type(self.mlp[0].weight.dtype)
107
- if condition is not None:
108
- t_freq = t_freq + self.cond_proj(condition)
109
-
110
- t = self.mlp(t_freq)
111
- t = t.unsqueeze(dim=1)
112
- return t
113
-
114
-
115
- class MLP(nn.Module):
116
- def __init__(self, *, width: int):
117
- super().__init__()
118
- self.width = width
119
- self.fc1 = nn.Linear(width, width * 4)
120
- self.fc2 = nn.Linear(width * 4, width)
121
- self.gelu = nn.GELU()
122
-
123
- def forward(self, x):
124
- return self.fc2(self.gelu(self.fc1(x)))
125
-
126
-
127
- class CrossAttention(nn.Module):
128
- def __init__(
129
- self,
130
- qdim,
131
- kdim,
132
- num_heads,
133
- qkv_bias=True,
134
- qk_norm=False,
135
- norm_layer=nn.LayerNorm,
136
- **kwargs,
137
- ):
138
- super().__init__()
139
- self.qdim = qdim
140
- self.kdim = kdim
141
- self.num_heads = num_heads
142
- assert self.qdim % num_heads == 0, "self.qdim must be divisible by num_heads"
143
- self.head_dim = self.qdim // num_heads
144
- assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"
145
- self.scale = self.head_dim ** -0.5
146
-
147
- self.to_q = nn.Linear(qdim, qdim, bias=qkv_bias)
148
- self.to_k = nn.Linear(kdim, qdim, bias=qkv_bias)
149
- self.to_v = nn.Linear(kdim, qdim, bias=qkv_bias)
150
-
151
- # TODO: eps should be 1 / 65530 if using fp16
152
- self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
153
- self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
154
- self.out_proj = nn.Linear(qdim, qdim, bias=True)
155
-
156
-
157
- def forward(self, x, y):
158
- """
159
- Parameters
160
- ----------
161
- x: torch.Tensor
162
- (batch, seqlen1, hidden_dim) (where hidden_dim = num heads * head dim)
163
- y: torch.Tensor
164
- (batch, seqlen2, hidden_dim2) - may contain padding (marked with -1)
165
- freqs_cis_img: torch.Tensor
166
- (batch, hidden_dim // 2), RoPE for image
167
- """
168
- b, s1, c = x.shape # [b, s1, D]
169
-
170
- # Detect padding tokens: check if all values in the feature dimension are -1
171
- # y_mask: [b, s2], True for valid tokens, False for padding
172
- y_mask = (y != -1).any(dim=-1) # [b, s2]
173
- has_padding = not y_mask.all()
174
-
175
- _, s2, c = y.shape # [b, s2, 1024]
176
- q = self.to_q(x)
177
- k = self.to_k(y)
178
- v = self.to_v(y)
179
-
180
- kv = torch.cat((k, v), dim=-1)
181
- split_size = kv.shape[-1] // self.num_heads // 2
182
- kv = kv.view(1, -1, self.num_heads, split_size * 2)
183
- k, v = torch.split(kv, split_size, dim=-1)
184
-
185
- q = q.view(b, s1, self.num_heads, self.head_dim) # [b, s1, h, d]
186
- k = k.view(b, s2, self.num_heads, self.head_dim) # [b, s2, h, d]
187
- v = v.view(b, s2, self.num_heads, self.head_dim) # [b, s2, h, d]
188
-
189
- q = self.q_norm(q)
190
- k = self.k_norm(k)
191
-
192
- if has_padding:
193
- seqlens_k = y_mask.sum(dim=1).int()
194
- q_flat = q.reshape(-1, self.num_heads, self.head_dim)
195
-
196
- # For k, v: only keep valid tokens (remove padding)
197
- # Create indices for valid tokens
198
- valid_indices = []
199
- cu_seqlens_k = [0]
200
- for i in range(b):
201
- valid_len = seqlens_k[i].item()
202
- batch_indices = torch.arange(valid_len, device=y.device) + i * s2
203
- valid_indices.append(batch_indices)
204
- cu_seqlens_k.append(cu_seqlens_k[-1] + valid_len)
205
-
206
- valid_indices = torch.cat(valid_indices)
207
- k_flat = k.reshape(b * s2, self.num_heads, self.head_dim)[valid_indices] # [total_k, h, d]
208
- v_flat = v.reshape(b * s2, self.num_heads, self.head_dim)[valid_indices] # [total_k, h, d]
209
-
210
- # Create cumulative sequence lengths
211
- cu_seqlens_q = torch.arange(0, (b + 1) * s1, s1, dtype=torch.int32, device=x.device)
212
- cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32, device=x.device)
213
-
214
- # Call flash attention varlen
215
- q_flat = q_flat.to(torch.bfloat16)
216
- k_flat = k_flat.to(torch.bfloat16)
217
- v_flat = v_flat.to(torch.bfloat16)
218
-
219
- context = flash_attn_varlen_func(
220
- q_flat, k_flat, v_flat,
221
- cu_seqlens_q, cu_seqlens_k,
222
- s1, seqlens_k.max().item(),
223
- dropout_p=0.0,
224
- softmax_scale=None,
225
- causal=False
226
- )
227
- context = context.reshape(b, s1, -1)
228
- else:
229
- with torch.backends.cuda.sdp_kernel(
230
- enable_flash=True,
231
- enable_math=False,
232
- enable_mem_efficient=True
233
- ):
234
- q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.num_heads), (q, k, v))
235
-
236
- attn_mask = None
237
- context = F.scaled_dot_product_attention(
238
- q, k, v, attn_mask=attn_mask
239
- ).transpose(1, 2).reshape(b, s1, -1)
240
-
241
- out = self.out_proj(context)
242
-
243
- return out
244
-
245
-
246
- class Attention(nn.Module):
247
- """
248
- We rename some layer names to align with flash attention
249
- """
250
-
251
- def __init__(
252
- self,
253
- dim,
254
- num_heads,
255
- qkv_bias=True,
256
- qk_norm=False,
257
- norm_layer=nn.LayerNorm,
258
- ):
259
- super().__init__()
260
- self.dim = dim
261
- self.num_heads = num_heads
262
- assert self.dim % num_heads == 0, 'dim should be divisible by num_heads'
263
- self.head_dim = self.dim // num_heads
264
- # This assertion is aligned with flash attention
265
- assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"
266
- self.scale = self.head_dim ** -0.5
267
-
268
- self.to_q = nn.Linear(dim, dim, bias=qkv_bias)
269
- self.to_k = nn.Linear(dim, dim, bias=qkv_bias)
270
- self.to_v = nn.Linear(dim, dim, bias=qkv_bias)
271
- self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
272
- self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
273
- self.out_proj = nn.Linear(dim, dim)
274
-
275
- # def forward(self, x):
276
- def forward(self, x, rotary_cos=None, rotary_sin=None):
277
- B, N, C = x.shape
278
-
279
- q = self.to_q(x)
280
- k = self.to_k(x)
281
- v = self.to_v(x)
282
-
283
- qkv = torch.cat((q, k, v), dim=-1)
284
- split_size = qkv.shape[-1] // self.num_heads // 3
285
- qkv = qkv.view(1, -1, self.num_heads, split_size * 3)
286
- q, k, v = torch.split(qkv, split_size, dim=-1)
287
-
288
- q = q.reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2) # [b, h, s, d]
289
- k = k.reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2) # [b, h, s, d]
290
- v = v.reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)
291
-
292
- q = self.q_norm(q) # [b, h, s, d]
293
- k = self.k_norm(k) # [b, h, s, d]
294
-
295
- # ========================= Apply RoPE =========================
296
- if rotary_cos is not None:
297
- q = apply_rotary_emb(q, rotary_cos, rotary_sin)
298
- k = apply_rotary_emb(k, rotary_cos, rotary_sin)
299
- # ==============================================================
300
-
301
- with torch.backends.cuda.sdp_kernel(
302
- enable_flash=True,
303
- enable_math=False,
304
- enable_mem_efficient=True
305
- ):
306
- x = F.scaled_dot_product_attention(q, k, v)
307
- x = x.transpose(1, 2).reshape(B, N, -1)
308
-
309
- x = self.out_proj(x)
310
- return x
311
-
312
-
313
- class DiTBlock(nn.Module):
314
- def __init__(
315
- self,
316
- hidden_size,
317
- c_emb_size,
318
- num_heads,
319
- text_states_dim=1024,
320
- use_flash_attn=False,
321
- qk_norm=False,
322
- norm_layer=nn.LayerNorm,
323
- qk_norm_layer=nn.RMSNorm,
324
- init_scale=1.0,
325
- qkv_bias=True,
326
- skip_connection=True,
327
- timested_modulate=False,
328
- use_moe: bool = False,
329
- num_experts: int = 8,
330
- moe_top_k: int = 2,
331
- **kwargs,
332
- ):
333
- super().__init__()
334
- self.use_flash_attn = use_flash_attn
335
- use_ele_affine = True
336
-
337
- # ========================= Self-Attention =========================
338
- self.norm1 = norm_layer(hidden_size, elementwise_affine=use_ele_affine, eps=1e-6)
339
- self.attn1 = Attention(hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm,
340
- norm_layer=qk_norm_layer)
341
-
342
- # ========================= FFN =========================
343
- self.norm2 = norm_layer(hidden_size, elementwise_affine=use_ele_affine, eps=1e-6)
344
-
345
- # ========================= Add =========================
346
- # Simply use add like SDXL.
347
- self.timested_modulate = timested_modulate
348
- if self.timested_modulate:
349
- self.default_modulation = nn.Sequential(
350
- nn.SiLU(),
351
- nn.Linear(c_emb_size, hidden_size, bias=True)
352
- )
353
-
354
- # ========================= Cross-Attention =========================
355
- self.attn2 = CrossAttention(hidden_size, text_states_dim, num_heads=num_heads, qkv_bias=qkv_bias,
356
- qk_norm=qk_norm, norm_layer=qk_norm_layer, init_scale=init_scale)
357
- self.norm3 = norm_layer(hidden_size, elementwise_affine=True, eps=1e-6)
358
-
359
- if skip_connection:
360
- self.skip_norm = norm_layer(hidden_size, elementwise_affine=True, eps=1e-6)
361
- self.skip_linear = nn.Linear(2 * hidden_size, hidden_size)
362
- else:
363
- self.skip_linear = None
364
-
365
- self.use_moe = use_moe
366
- if self.use_moe:
367
- self.moe = MoEBlock(
368
- hidden_size,
369
- num_experts=num_experts,
370
- moe_top_k=moe_top_k,
371
- dropout=0.0,
372
- activation_fn="gelu",
373
- final_dropout=False,
374
- ff_inner_dim=int(hidden_size * 4.0),
375
- ff_bias=True,
376
- )
377
- else:
378
- self.mlp = MLP(width=hidden_size)
379
-
380
- def forward(self, x, c=None, text_states=None, skip_value=None, rotary_cos=None, rotary_sin=None):
381
-
382
- if self.skip_linear is not None:
383
- cat = torch.cat([skip_value, x], dim=-1)
384
- x = self.skip_linear(cat)
385
- x = self.skip_norm(x)
386
-
387
- # Self-Attention
388
- if self.timested_modulate:
389
- shift_msa = self.default_modulation(c).unsqueeze(dim=1)
390
- x = x + shift_msa
391
-
392
- attn_out = self.attn1(self.norm1(x), rotary_cos=rotary_cos, rotary_sin=rotary_sin)
393
-
394
- x = x + attn_out
395
-
396
- # Cross-Attention
397
- x = x + self.attn2(self.norm2(x), text_states)
398
-
399
- # FFN Layer
400
- mlp_inputs = self.norm3(x)
401
-
402
- if self.use_moe:
403
- x = x + self.moe(mlp_inputs)
404
- else:
405
- x = x + self.mlp(mlp_inputs)
406
-
407
- return x
408
-
409
-
410
- class AttentionPool(nn.Module):
411
- def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
412
- super().__init__()
413
- self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim ** 0.5)
414
- self.k_proj = nn.Linear(embed_dim, embed_dim)
415
- self.q_proj = nn.Linear(embed_dim, embed_dim)
416
- self.v_proj = nn.Linear(embed_dim, embed_dim)
417
- self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
418
- self.num_heads = num_heads
419
-
420
- def forward(self, x, attention_mask=None):
421
- x = x.permute(1, 0, 2) # NLC -> LNC
422
- if attention_mask is not None:
423
- attention_mask = attention_mask.unsqueeze(-1).permute(1, 0, 2)
424
- global_emb = (x * attention_mask).sum(dim=0) / attention_mask.sum(dim=0)
425
- x = torch.cat([global_emb[None,], x], dim=0)
426
-
427
- else:
428
- x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC
429
- x = x + self.positional_embedding[:, None, :].to(x.dtype) # (L+1)NC
430
- x, _ = F.multi_head_attention_forward(
431
- query=x[:1], key=x, value=x,
432
- embed_dim_to_check=x.shape[-1],
433
- num_heads=self.num_heads,
434
- q_proj_weight=self.q_proj.weight,
435
- k_proj_weight=self.k_proj.weight,
436
- v_proj_weight=self.v_proj.weight,
437
- in_proj_weight=None,
438
- in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
439
- bias_k=None,
440
- bias_v=None,
441
- add_zero_attn=False,
442
- dropout_p=0,
443
- out_proj_weight=self.c_proj.weight,
444
- out_proj_bias=self.c_proj.bias,
445
- use_separate_proj_weight=True,
446
- training=self.training,
447
- need_weights=False
448
- )
449
- return x.squeeze(0)
450
-
451
-
452
- class FinalLayer(nn.Module):
453
- """
454
- The final layer of DiT.
455
- """
456
-
457
- def __init__(self, final_hidden_size, out_channels):
458
- super().__init__()
459
- self.final_hidden_size = final_hidden_size
460
- self.norm_final = nn.LayerNorm(final_hidden_size, elementwise_affine=True, eps=1e-6)
461
- self.linear = nn.Linear(final_hidden_size, out_channels, bias=True)
462
-
463
- def forward(self, x):
464
- x = self.norm_final(x)
465
- x = x[:, 1:]
466
- x = self.linear(x)
467
- return x
468
-
469
-
470
- class RefineDiT(nn.Module):
471
-
472
- @classmethod
473
- @synchronize_timer('Refine Model Loading')
474
- def from_single_file(
475
- cls,
476
- ckpt_path,
477
- config_path,
478
- device='cuda',
479
- dtype=torch.float16,
480
- use_safetensors=None,
481
- **kwargs,
482
- ):
483
- # load config
484
- with open(config_path, 'r') as f:
485
- config = yaml.safe_load(f)
486
-
487
- # load ckpt
488
- if use_safetensors:
489
- ckpt_path = ckpt_path.replace('.ckpt', '.safetensors')
490
- if not os.path.exists(ckpt_path):
491
- raise FileNotFoundError(f"Model file {ckpt_path} not found")
492
-
493
- logger.info(f"Loading model from {ckpt_path}")
494
- if use_safetensors:
495
- import safetensors.torch
496
- ckpt = safetensors.torch.load_file(ckpt_path, device='cpu')
497
- else:
498
- ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=True)
499
-
500
- if 'model' in ckpt:
501
- ckpt = ckpt['model']
502
- if 'model' in config:
503
- config = config['model']
504
-
505
- model_kwargs = config['params']
506
- model_kwargs.update(kwargs)
507
-
508
- model = cls(**model_kwargs)
509
- model.load_state_dict(ckpt)
510
- model.to(device=device, dtype=dtype)
511
- return model
512
-
513
- @classmethod
514
- def from_pretrained(
515
- cls,
516
- model_path,
517
- device='cuda',
518
- dtype=torch.float16,
519
- use_safetensors=False,
520
- variant='fp16',
521
- subfolder='hunyuan3d-dit-v2-1',
522
- **kwargs,
523
- ):
524
- config_path, ckpt_path = smart_load_model(
525
- model_path,
526
- subfolder=subfolder,
527
- use_safetensors=use_safetensors,
528
- variant=variant
529
- )
530
-
531
- return cls.from_single_file(
532
- ckpt_path,
533
- config_path,
534
- device=device,
535
- dtype=dtype,
536
- use_safetensors=use_safetensors,
537
- **kwargs
538
- )
539
-
540
- def __init__(
541
- self,
542
- input_size=1024,
543
- in_channels=4,
544
- hidden_size=1024,
545
- context_dim=1024,
546
- depth=24,
547
- num_heads=16,
548
- mlp_ratio=4.0,
549
- norm_type='layer',
550
- qk_norm_type='rms',
551
- qk_norm=False,
552
- text_len=257,
553
- guidance_cond_proj_dim=None,
554
- qkv_bias=True,
555
- num_moe_layers: int = 6,
556
- num_experts: int = 8,
557
- moe_top_k: int = 2,
558
- voxel_query_res: int = 128,
559
- **kwargs
560
- ):
561
- super().__init__()
562
- self.input_size = input_size
563
- self.depth = depth
564
- self.in_channels = in_channels
565
- self.out_channels = in_channels
566
- self.num_heads = num_heads
567
-
568
- self.hidden_size = hidden_size
569
- self.norm = nn.LayerNorm if norm_type == 'layer' else nn.RMSNorm
570
- self.qk_norm = nn.RMSNorm if qk_norm_type == 'rms' else nn.LayerNorm
571
- self.context_dim = context_dim
572
- self.voxel_query_res = voxel_query_res
573
-
574
- self.guidance_cond_proj_dim = guidance_cond_proj_dim
575
-
576
- self.text_len = text_len
577
-
578
- self.x_embedder = nn.Linear(in_channels, hidden_size, bias=True)
579
- self.t_embedder = TimestepEmbedder(hidden_size, hidden_size * 4, cond_proj_dim=guidance_cond_proj_dim)
580
-
581
- self.blocks = nn.ModuleList([
582
- DiTBlock(hidden_size=hidden_size,
583
- c_emb_size=hidden_size,
584
- num_heads=num_heads,
585
- mlp_ratio=mlp_ratio,
586
- text_states_dim=context_dim,
587
- qk_norm=qk_norm,
588
- norm_layer=self.norm,
589
- qk_norm_layer=self.qk_norm,
590
- skip_connection=layer > depth // 2,
591
- qkv_bias=qkv_bias,
592
- use_moe=True if depth - layer <= num_moe_layers else False,
593
- num_experts=num_experts,
594
- moe_top_k=moe_top_k
595
- )
596
- for layer in range(depth)
597
- ])
598
- self.depth = depth
599
-
600
- self.final_layer = FinalLayer(hidden_size, self.out_channels)
601
-
602
- def forward(self, x, t, contexts, **kwargs):
603
- cond = contexts['main']
604
-
605
- t = self.t_embedder(t, condition=kwargs.get('guidance_cond'))
606
- x = self.x_embedder(x)
607
- c = t
608
-
609
- ##########################################
610
- head_dim = self.blocks[0].attn1.head_dim
611
- num_cond_tokens = c.shape[1] if c.dim() == 3 else 1
612
-
613
- device = x.device
614
- cond_cos = torch.ones(x.shape[0], num_cond_tokens, head_dim, device=device)
615
- cond_sin = torch.zeros(x.shape[0], num_cond_tokens, head_dim, device=device)
616
-
617
- voxel_cond = kwargs.get('voxel_cond')
618
- # rotary_cos_vox, rotary_sin_vox = precompute_freqs_cis_3d(head_dim, voxel_cond)
619
- rotary_cos_vox, rotary_sin_vox = precompute_freqs_cis_3d_interpolated(
620
- head_dim, voxel_cond, current_res=self.voxel_query_res)
621
-
622
- rotary_cos = torch.cat([cond_cos, rotary_cos_vox], dim=1)
623
- rotary_sin = torch.cat([cond_sin, rotary_sin_vox], dim=1)
624
- ##########################################
625
-
626
- x = torch.cat([c, x], dim=1)
627
-
628
- skip_value_list = []
629
- for layer, block in enumerate(self.blocks):
630
- skip_value = None if layer <= self.depth // 2 else skip_value_list.pop()
631
- x = block(x, c, cond, rotary_cos=rotary_cos, rotary_sin=rotary_sin, skip_value=skip_value)
632
- if layer < self.depth // 2:
633
- skip_value_list.append(x)
634
-
635
- x = self.final_layer(x)
636
- return x
637
-
638
-
639
- def apply_rotary_emb(x, cos, sin):
640
- """
641
- x: [B, H, N, D]
642
- cos, sin: [B, N, D]
643
- """
644
-
645
- cos = cos.unsqueeze(1)
646
- sin = sin.unsqueeze(1)
647
-
648
- def rotate_half(x):
649
- x1, x2 = x.chunk(2, dim=-1)
650
- return torch.cat((-x2, x1), dim=-1)
651
-
652
- return (x * cos) + (rotate_half(x) * sin)
653
-
654
-
655
- def precompute_freqs_cis_3d(dim: int, grid_indices: torch.Tensor, theta: float = 10000.0):
656
- """
657
- grid_indices: [B, N, 3] voxel idx
658
- """
659
- dim_x = dim // 3
660
- dim_y = dim // 3
661
- dim_z = dim - dim_x - dim_y
662
-
663
- device = grid_indices.device
664
- freqs_x = 1.0 / (theta ** (torch.arange(0, dim_x, 2, device=device).float() / dim_x))
665
- freqs_y = 1.0 / (theta ** (torch.arange(0, dim_y, 2, device=device).float() / dim_y))
666
- freqs_z = 1.0 / (theta ** (torch.arange(0, dim_z, 2, device=device).float() / dim_z))
667
-
668
- x_idx = grid_indices[..., 0].float()
669
- y_idx = grid_indices[..., 1].float()
670
- z_idx = grid_indices[..., 2].float()
671
-
672
- args_x = x_idx.unsqueeze(-1) * freqs_x.unsqueeze(0).unsqueeze(0)
673
- args_y = y_idx.unsqueeze(-1) * freqs_y.unsqueeze(0).unsqueeze(0)
674
- args_z = z_idx.unsqueeze(-1) * freqs_z.unsqueeze(0).unsqueeze(0)
675
-
676
- args = torch.cat([args_x, args_y, args_z], dim=-1)
677
- args = torch.cat([args, args], dim=-1)
678
-
679
- return torch.cos(args), torch.sin(args)
680
-
681
-
682
- def precompute_freqs_cis_3d_interpolated(
683
- dim: int,
684
- grid_indices: torch.Tensor,
685
- theta: float = 10000.0,
686
- trained_res: float = 128.0, # training resolution
687
- current_res: float = 256.0, # inference resolution
688
- ):
689
- scale_factor = current_res / trained_res
690
-
691
- dim_x = dim // 3
692
- dim_y = dim // 3
693
- dim_z = dim - dim_x - dim_y
694
-
695
- device = grid_indices.device
696
-
697
- freqs_x = 1.0 / (theta ** (torch.arange(0, dim_x, 2, device=device).float() / dim_x))
698
- freqs_y = 1.0 / (theta ** (torch.arange(0, dim_y, 2, device=device).float() / dim_y))
699
- freqs_z = 1.0 / (theta ** (torch.arange(0, dim_z, 2, device=device).float() / dim_z))
700
-
701
- num_freqs_x = dim_x // 2 + (dim_x % 2)
702
- num_freqs_y = dim_y // 2 + (dim_y % 2)
703
- target_len = dim // 2
704
- freqs_x = freqs_x[:num_freqs_x]
705
- freqs_y = freqs_y[:num_freqs_y]
706
- freqs_z = freqs_z[:(target_len - len(freqs_x) - len(freqs_y))]
707
-
708
- input_x = grid_indices[..., 0].float()
709
- input_y = grid_indices[..., 1].float()
710
- input_z = grid_indices[..., 2].float()
711
-
712
- # Apply Scaling
713
- pos_x = input_x / scale_factor
714
- pos_y = input_y / scale_factor
715
- pos_z = input_z / scale_factor
716
-
717
- # pos * freq
718
- args_x = pos_x.unsqueeze(-1) * freqs_x.unsqueeze(0).unsqueeze(0)
719
- args_y = pos_y.unsqueeze(-1) * freqs_y.unsqueeze(0).unsqueeze(0)
720
- args_z = pos_z.unsqueeze(-1) * freqs_z.unsqueeze(0).unsqueeze(0)
721
-
722
- args = torch.cat([args_x, args_y, args_z], dim=-1)
723
- args = torch.cat([args, args], dim=-1)
724
-
725
- return torch.cos(args), torch.sin(args)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ultrashape/models/denoisers/moe_layers.py DELETED
@@ -1,177 +0,0 @@
1
- # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
2
- # except for the third-party components listed below.
3
- # Hunyuan 3D does not impose any additional limitations beyond what is outlined
4
- # in the repsective licenses of these third-party components.
5
- # Users must comply with all terms and conditions of original licenses of these third-party
6
- # components and must ensure that the usage of the third party components adheres to
7
- # all relevant laws and regulations.
8
-
9
- # For avoidance of doubts, Hunyuan 3D means the large language models and
10
- # their software and algorithms, including trained model weights, parameters (including
11
- # optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
12
- # fine-tuning enabling code and other elements of the foregoing made publicly available
13
- # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
14
-
15
- import torch
16
- import torch.nn as nn
17
- import numpy as np
18
- import math
19
- from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
20
-
21
- import torch.nn.functional as F
22
- from diffusers.models.attention import FeedForward
23
-
24
- class AddAuxiliaryLoss(torch.autograd.Function):
25
- """
26
- The trick function of adding auxiliary (aux) loss,
27
- which includes the gradient of the aux loss during backpropagation.
28
- """
29
- @staticmethod
30
- def forward(ctx, x, loss):
31
- assert loss.numel() == 1
32
- ctx.dtype = loss.dtype
33
- ctx.required_aux_loss = loss.requires_grad
34
- return x
35
-
36
- @staticmethod
37
- def backward(ctx, grad_output):
38
- grad_loss = None
39
- if ctx.required_aux_loss:
40
- grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device)
41
- return grad_output, grad_loss
42
-
43
- class MoEGate(nn.Module):
44
- def __init__(self, embed_dim, num_experts=16, num_experts_per_tok=2, aux_loss_alpha=0.01):
45
- super().__init__()
46
- self.top_k = num_experts_per_tok
47
- self.n_routed_experts = num_experts
48
-
49
- self.scoring_func = 'softmax'
50
- self.alpha = aux_loss_alpha
51
- self.seq_aux = False
52
-
53
- # topk selection algorithm
54
- self.norm_topk_prob = False
55
- self.gating_dim = embed_dim
56
- self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
57
- self.reset_parameters()
58
-
59
- def reset_parameters(self) -> None:
60
- import torch.nn.init as init
61
- init.kaiming_uniform_(self.weight, a=math.sqrt(5))
62
-
63
- def forward(self, hidden_states):
64
- bsz, seq_len, h = hidden_states.shape
65
- # print(bsz, seq_len, h)
66
- ### compute gating score
67
- hidden_states = hidden_states.view(-1, h)
68
- logits = F.linear(hidden_states, self.weight, None)
69
- if self.scoring_func == 'softmax':
70
- scores = logits.softmax(dim=-1)
71
- else:
72
- raise NotImplementedError(f'insupportable scoring function for MoE gating: {self.scoring_func}')
73
-
74
- ### select top-k experts
75
- topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
76
-
77
- ### norm gate to sum 1
78
- if self.top_k > 1 and self.norm_topk_prob:
79
- denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
80
- topk_weight = topk_weight / denominator
81
-
82
- ### expert-level computation auxiliary loss
83
- if self.training and self.alpha > 0.0:
84
- scores_for_aux = scores
85
- aux_topk = self.top_k
86
- # always compute aux loss based on the naive greedy topk method
87
- topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
88
- if self.seq_aux:
89
- scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
90
- ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
91
- ce.scatter_add_(
92
- 1,
93
- topk_idx_for_aux_loss,
94
- torch.ones(
95
- bsz, seq_len * aux_topk,
96
- device=hidden_states.device
97
- )
98
- ).div_(seq_len * aux_topk / self.n_routed_experts)
99
- aux_loss = (ce * scores_for_seq_aux.mean(dim = 1)).sum(dim = 1).mean()
100
- aux_loss = aux_loss * self.alpha
101
- else:
102
- mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1),
103
- num_classes=self.n_routed_experts)
104
- ce = mask_ce.float().mean(0)
105
- Pi = scores_for_aux.mean(0)
106
- fi = ce * self.n_routed_experts
107
- aux_loss = (Pi * fi).sum() * self.alpha
108
- else:
109
- aux_loss = None
110
- return topk_idx, topk_weight, aux_loss
111
-
112
- class MoEBlock(nn.Module):
113
- def __init__(self, dim, num_experts=8, moe_top_k=2,
114
- activation_fn = "gelu", dropout=0.0, final_dropout = False,
115
- ff_inner_dim = None, ff_bias = True):
116
- super().__init__()
117
- self.moe_top_k = moe_top_k
118
- self.experts = nn.ModuleList([
119
- FeedForward(dim,dropout=dropout,
120
- activation_fn=activation_fn,
121
- final_dropout=final_dropout,
122
- inner_dim=ff_inner_dim,
123
- bias=ff_bias)
124
- for i in range(num_experts)])
125
- self.gate = MoEGate(embed_dim=dim, num_experts=num_experts, num_experts_per_tok=moe_top_k)
126
-
127
- self.shared_experts = FeedForward(dim,dropout=dropout, activation_fn=activation_fn,
128
- final_dropout=final_dropout, inner_dim=ff_inner_dim,
129
- bias=ff_bias)
130
-
131
- def initialize_weight(self):
132
- pass
133
-
134
- def forward(self, hidden_states):
135
- identity = hidden_states
136
- orig_shape = hidden_states.shape
137
- topk_idx, topk_weight, aux_loss = self.gate(hidden_states)
138
-
139
- hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
140
- flat_topk_idx = topk_idx.view(-1)
141
- if self.training:
142
- hidden_states = hidden_states.repeat_interleave(self.moe_top_k, dim=0)
143
- y = torch.empty_like(hidden_states, dtype=hidden_states.dtype)
144
- for i, expert in enumerate(self.experts):
145
- tmp = expert(hidden_states[flat_topk_idx == i])
146
- y[flat_topk_idx == i] = tmp.to(hidden_states.dtype)
147
- y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
148
- y = y.view(*orig_shape)
149
- y = AddAuxiliaryLoss.apply(y, aux_loss)
150
- else:
151
- y = self.moe_infer(hidden_states, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
152
- y = y + self.shared_experts(identity)
153
- return y
154
-
155
-
156
- @torch.no_grad()
157
- def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
158
- expert_cache = torch.zeros_like(x)
159
- idxs = flat_expert_indices.argsort()
160
- tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
161
- token_idxs = idxs // self.moe_top_k
162
- for i, end_idx in enumerate(tokens_per_expert):
163
- start_idx = 0 if i == 0 else tokens_per_expert[i-1]
164
- if start_idx == end_idx:
165
- continue
166
- expert = self.experts[i]
167
- exp_token_idx = token_idxs[start_idx:end_idx]
168
- expert_tokens = x[exp_token_idx]
169
- expert_out = expert(expert_tokens)
170
- expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
171
-
172
- # for fp16 and other dtype
173
- expert_cache = expert_cache.to(expert_out.dtype)
174
- expert_cache.scatter_reduce_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]),
175
- expert_out,
176
- reduce='sum')
177
- return expert_cache
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ultrashape/models/diffusion/flow_matching_dit_trainer.py DELETED
@@ -1,313 +0,0 @@
1
-
2
- import os
3
- from contextlib import contextmanager
4
- from typing import List, Tuple, Optional, Union
5
-
6
- import torch
7
- import torch.nn as nn
8
- from torch.optim import lr_scheduler
9
- import pytorch_lightning as pl
10
- from pytorch_lightning.utilities import rank_zero_info
11
- from pytorch_lightning.utilities import rank_zero_only
12
- from ultrashape.pipelines import export_to_trimesh
13
-
14
- from ...utils.ema import LitEma
15
- from ...utils.misc import instantiate_from_config, instantiate_non_trainable_model, instantiate_vae_model, instantiate_vae_model_local
16
-
17
-
18
- class Diffuser(pl.LightningModule):
19
- def __init__(
20
- self,
21
- *,
22
- vae_config,
23
- cond_config,
24
- dit_cfg,
25
- scheduler_cfg,
26
- optimizer_cfg,
27
- pipeline_cfg=None,
28
- image_processor_cfg=None,
29
- lora_config=None,
30
- ema_config=None,
31
- scale_by_std: bool = False,
32
- z_scale_factor: float = 1.0,
33
- ckpt_path: Optional[str] = None,
34
- ignore_keys: Union[Tuple[str], List[str]] = (),
35
- torch_compile: bool = False,
36
- ):
37
- super().__init__()
38
-
39
- # ========= init optimizer config ========= #
40
- self.optimizer_cfg = optimizer_cfg
41
-
42
- # ========= init diffusion scheduler ========= #
43
- self.scheduler_cfg = scheduler_cfg
44
- self.sampler = None
45
- if 'transport' in scheduler_cfg:
46
- self.transport = instantiate_from_config(scheduler_cfg.transport)
47
- self.sampler = instantiate_from_config(scheduler_cfg.sampler, transport=self.transport)
48
- self.sample_fn = self.sampler.sample_ode(**scheduler_cfg.sampler.ode_params)
49
-
50
- # ========= init the model ========= #
51
- self.dit_cfg = dit_cfg
52
- self.model = instantiate_from_config(dit_cfg, device=None, dtype=None)
53
-
54
- self.cond_stage_model = instantiate_from_config(cond_config)
55
-
56
- self.ckpt_path = ckpt_path
57
- if ckpt_path is not None:
58
- self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
59
-
60
- # ========= config lora model ========= #
61
- if lora_config is not None:
62
- from peft import LoraConfig, get_peft_model
63
- loraconfig = LoraConfig(
64
- r=lora_config.rank,
65
- lora_alpha=lora_config.rank,
66
- target_modules=lora_config.get('target_modules')
67
- )
68
- self.model = get_peft_model(self.model, loraconfig)
69
-
70
- # ========= config ema model ========= #
71
- self.ema_config = ema_config
72
- if self.ema_config is not None:
73
- if self.ema_config.ema_model == 'DSEma':
74
- # from michelangelo.models.modules.ema_deepspeed import DSEma
75
- from ..utils.ema_deepspeed import DSEma
76
- self.model_ema = DSEma(self.model, decay=self.ema_config.ema_decay)
77
- else:
78
- self.model_ema = LitEma(self.model, decay=self.ema_config.ema_decay)
79
- #do not initilize EMA weight from ckpt path, since I need to change moe layers
80
- if ckpt_path is not None:
81
- self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
82
- print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
83
-
84
- # ========= init vae at last to prevent it is overridden by loaded ckpt ========= #
85
- self.first_stage_model = instantiate_vae_model_local(vae_config)
86
- self.first_stage_model.enable_flashvdm_decoder()
87
-
88
- self.scale_by_std = scale_by_std
89
- if scale_by_std:
90
- self.register_buffer("z_scale_factor", torch.tensor(z_scale_factor))
91
- else:
92
- self.z_scale_factor = z_scale_factor
93
-
94
- # ========= init pipeline for inference ========= #
95
- self.image_processor_cfg = image_processor_cfg
96
- self.image_processor = None
97
- if self.image_processor_cfg is not None:
98
- self.image_processor = instantiate_from_config(self.image_processor_cfg)
99
- self.pipeline_cfg = pipeline_cfg
100
- from ...schedulers import FlowMatchEulerDiscreteScheduler
101
- scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
102
- self.pipeline = instantiate_from_config(
103
- pipeline_cfg,
104
- vae=self.first_stage_model,
105
- model=self.model,
106
- scheduler=scheduler,
107
- conditioner=self.cond_stage_model,
108
- image_processor=self.image_processor,
109
- )
110
-
111
- # ========= torch compile to accelerate ========= #
112
- self.torch_compile = torch_compile
113
- if self.torch_compile:
114
- torch.nn.Module.compile(self.model)
115
- torch.nn.Module.compile(self.first_stage_model)
116
- torch.nn.Module.compile(self.cond_stage_model)
117
- print(f'*' * 100)
118
- print(f'Compile model for acceleration')
119
- print(f'*' * 100)
120
-
121
- @contextmanager
122
- def ema_scope(self, context=None):
123
- if self.ema_config is not None and self.ema_config.get('ema_inference', False):
124
- self.model_ema.store(self.model)
125
- self.model_ema.copy_to(self.model)
126
- if context is not None:
127
- print(f"{context}: Switched to EMA weights")
128
- try:
129
- yield None
130
- finally:
131
- if self.ema_config is not None and self.ema_config.get('ema_inference', False):
132
- self.model_ema.restore(self.model)
133
- if context is not None:
134
- print(f"{context}: Restored training weights")
135
-
136
- def init_from_ckpt(self, path, ignore_keys=()):
137
- ckpt = torch.load(path, map_location="cpu")
138
- if 'state_dict' not in ckpt:
139
- # deepspeed ckpt
140
- state_dict = {}
141
- for k in ckpt.keys():
142
- new_k = k.replace('_forward_module.', '')
143
- state_dict[new_k] = ckpt[k]
144
- else:
145
- state_dict = ckpt["state_dict"]
146
-
147
- keys = list(state_dict.keys())
148
- for k in keys:
149
- for ik in ignore_keys:
150
- if ik in k:
151
- print("Deleting key {} from state_dict.".format(k))
152
- del state_dict[k]
153
-
154
- missing, unexpected = self.load_state_dict(state_dict, strict=False)
155
- print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
156
- if len(missing) > 0:
157
- print(f"Missing Keys: {missing}")
158
- print(f"Unexpected Keys: {unexpected}")
159
-
160
- def on_load_checkpoint(self, checkpoint):
161
- """
162
- The pt_model is trained separately, so we already have access to its
163
- checkpoint and load it separately with `self.set_pt_model`.
164
-
165
- However, the PL Trainer is strict about
166
- checkpoint loading (not configurable), so it expects the loaded state_dict
167
- to match exactly the keys in the model state_dict.
168
-
169
- So, when loading the checkpoint, before matching keys, we add all pt_model keys
170
- from self.state_dict() to the checkpoint state dict, so that they match
171
- """
172
- for key in self.state_dict().keys():
173
- if key.startswith("model_ema") and key not in checkpoint["state_dict"]:
174
- checkpoint["state_dict"][key] = self.state_dict()[key]
175
-
176
- def configure_optimizers(self) -> Tuple[List, List]:
177
- lr = self.learning_rate
178
-
179
- params_list = []
180
- trainable_parameters = list(self.model.parameters())
181
- params_list.append({'params': trainable_parameters, 'lr': lr})
182
-
183
- no_decay = ['bias', 'norm.weight', 'norm.bias', 'norm1.weight', 'norm1.bias', 'norm2.weight', 'norm2.bias']
184
-
185
-
186
- if self.optimizer_cfg.get('train_image_encoder', False):
187
- image_encoder_parameters = list(self.cond_stage_model.named_parameters())
188
- image_encoder_parameters_decay = [param for name, param in image_encoder_parameters if
189
- not any((no_decay_name in name) for no_decay_name in no_decay)]
190
- image_encoder_parameters_nodecay = [param for name, param in image_encoder_parameters if
191
- any((no_decay_name in name) for no_decay_name in no_decay)]
192
- # filter trainable params
193
- image_encoder_parameters_decay = [param for param in image_encoder_parameters_decay if
194
- param.requires_grad]
195
- image_encoder_parameters_nodecay = [param for param in image_encoder_parameters_nodecay if
196
- param.requires_grad]
197
-
198
- print(f"Image Encoder Params: {len(image_encoder_parameters_decay)} decay, ")
199
- print(f"Image Encoder Params: {len(image_encoder_parameters_nodecay)} nodecay, ")
200
-
201
- image_encoder_lr = self.optimizer_cfg['image_encoder_lr']
202
- image_encoder_lr_multiply = self.optimizer_cfg.get('image_encoder_lr_multiply', 1.0)
203
- image_encoder_lr = image_encoder_lr if image_encoder_lr is not None else lr * image_encoder_lr_multiply
204
- params_list.append(
205
- {'params': image_encoder_parameters_decay, 'lr': image_encoder_lr,
206
- 'weight_decay': 0.05})
207
- params_list.append(
208
- {'params': image_encoder_parameters_nodecay, 'lr': image_encoder_lr,
209
- 'weight_decay': 0.})
210
-
211
- optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=params_list, lr=lr)
212
- if hasattr(self.optimizer_cfg, 'scheduler'):
213
- scheduler_func = instantiate_from_config(
214
- self.optimizer_cfg.scheduler,
215
- max_decay_steps=self.trainer.max_steps,
216
- lr_max=lr
217
- )
218
- scheduler = {
219
- "scheduler": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule),
220
- "interval": "step",
221
- "frequency": 1
222
- }
223
- schedulers = [scheduler]
224
- else:
225
- schedulers = []
226
- optimizers = [optimizer]
227
-
228
- return optimizers, schedulers
229
-
230
-
231
- def on_train_batch_end(self, *args, **kwargs):
232
- if self.ema_config is not None:
233
- self.model_ema(self.model)
234
-
235
- def on_train_epoch_start(self) -> None:
236
- pl.seed_everything(self.trainer.global_rank)
237
-
238
- def forward(self, batch, disable_drop):
239
- with torch.autocast(device_type="cuda", dtype=torch.bfloat16): #float32 for text
240
- contexts = self.cond_stage_model(image=batch.get('image'), text=batch.get('text'), mask=batch.get('mask'), disable_drop=disable_drop)
241
-
242
- with torch.autocast(device_type="cuda", dtype=torch.float16):
243
- with torch.no_grad():
244
- latents, voxel_idx = self.first_stage_model.encode(batch["surface"], sample_posterior=True, need_voxel=True)
245
- latents = self.z_scale_factor * latents
246
- # print(latents.shape)
247
-
248
- with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
249
- loss = self.transport.training_losses(self.model, latents,
250
- dict(contexts=contexts, voxel_cond=voxel_idx))["loss"].mean()
251
-
252
- return loss
253
-
254
- def training_step(self, batch, batch_idx, optimizer_idx=0):
255
- loss = self.forward(batch, disable_drop=False)
256
- split = 'train'
257
- loss_dict = {
258
- f"{split}/total_loss": loss.detach(),
259
- f"{split}/lr_abs": self.optimizers().param_groups[0]['lr'],
260
- }
261
- self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True)
262
-
263
- return loss
264
-
265
- def validation_step(self, batch, batch_idx, optimizer_idx=0):
266
- loss = self.forward(batch, disable_drop=True)
267
- split = 'val'
268
- loss_dict = {
269
- f"{split}/total_loss": loss.detach(),
270
- f"{split}/lr_abs": self.optimizers().param_groups[0]['lr'],
271
- }
272
- self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True)
273
-
274
- return loss
275
-
276
- @torch.no_grad()
277
- def sample(self, batch, output_type='trimesh', **kwargs):
278
- self.cond_stage_model.disable_drop = True
279
-
280
- generator = torch.Generator().manual_seed(0)
281
-
282
- with self.ema_scope("Sample"):
283
- with torch.amp.autocast(device_type='cuda'):
284
- try:
285
- self.pipeline.device = self.device
286
- self.pipeline.dtype = self.dtype
287
- print("### USING PIPELINE ###")
288
- print(f'device: {self.device} dtype : {self.dtype}')
289
- additional_params = {'output_type':output_type}
290
-
291
- image = batch.get("image", None)
292
- mask = batch.get('mask', None)
293
-
294
- outputs = self.pipeline(image=image,
295
- mask=mask,
296
- generator=generator,
297
- box_v=1.0,
298
- mc_level=0.0,
299
- octree_resolution=1024,
300
- **additional_params)
301
-
302
- except Exception as e:
303
- import traceback
304
- traceback.print_exc()
305
- print(f"Unexpected {e=}, {type(e)=}")
306
- with open("error.txt", "a") as f:
307
- f.write(str(e))
308
- f.write(traceback.format_exc())
309
- f.write("\n")
310
- outputs = [None]
311
-
312
- self.cond_stage_model.disable_drop = False
313
- return [outputs]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ultrashape/models/diffusion/transport/__init__.py DELETED
@@ -1,97 +0,0 @@
1
- # This file includes code derived from the SiT project (https://github.com/willisma/SiT),
2
- # which is licensed under the MIT License.
3
- #
4
- # MIT License
5
- #
6
- # Copyright (c) Meta Platforms, Inc. and affiliates.
7
- #
8
- # Permission is hereby granted, free of charge, to any person obtaining a copy
9
- # of this software and associated documentation files (the "Software"), to deal
10
- # in the Software without restriction, including without limitation the rights
11
- # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
- # copies of the Software, and to permit persons to whom the Software is
13
- # furnished to do so, subject to the following conditions:
14
- #
15
- # The above copyright notice and this permission notice shall be included in all
16
- # copies or substantial portions of the Software.
17
- #
18
- # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19
- # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20
- # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21
- # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22
- # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23
- # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24
- # SOFTWARE.
25
-
26
- from .transport import Transport, ModelType, WeightType, PathType, Sampler
27
-
28
-
29
- def create_transport(
30
- path_type='Linear',
31
- prediction="velocity",
32
- loss_weight=None,
33
- train_eps=None,
34
- sample_eps=None,
35
- train_sample_type="uniform",
36
- mean = 0.0,
37
- std = 1.0,
38
- shift_scale = 1.0,
39
- ):
40
- """function for creating Transport object
41
- **Note**: model prediction defaults to velocity
42
- Args:
43
- - path_type: type of path to use; default to linear
44
- - learn_score: set model prediction to score
45
- - learn_noise: set model prediction to noise
46
- - velocity_weighted: weight loss by velocity weight
47
- - likelihood_weighted: weight loss by likelihood weight
48
- - train_eps: small epsilon for avoiding instability during training
49
- - sample_eps: small epsilon for avoiding instability during sampling
50
- """
51
-
52
- if prediction == "noise":
53
- model_type = ModelType.NOISE
54
- elif prediction == "score":
55
- model_type = ModelType.SCORE
56
- else:
57
- model_type = ModelType.VELOCITY
58
-
59
- if loss_weight == "velocity":
60
- loss_type = WeightType.VELOCITY
61
- elif loss_weight == "likelihood":
62
- loss_type = WeightType.LIKELIHOOD
63
- else:
64
- loss_type = WeightType.NONE
65
-
66
- path_choice = {
67
- "Linear": PathType.LINEAR,
68
- "GVP": PathType.GVP,
69
- "VP": PathType.VP,
70
- }
71
-
72
- path_type = path_choice[path_type]
73
-
74
- if (path_type in [PathType.VP]):
75
- train_eps = 1e-5 if train_eps is None else train_eps
76
- sample_eps = 1e-3 if train_eps is None else sample_eps
77
- elif (path_type in [PathType.GVP, PathType.LINEAR] and model_type != ModelType.VELOCITY):
78
- train_eps = 1e-3 if train_eps is None else train_eps
79
- sample_eps = 1e-3 if train_eps is None else sample_eps
80
- else: # velocity & [GVP, LINEAR] is stable everywhere
81
- train_eps = 0
82
- sample_eps = 0
83
-
84
- # create flow state
85
- state = Transport(
86
- model_type=model_type,
87
- path_type=path_type,
88
- loss_type=loss_type,
89
- train_eps=train_eps,
90
- sample_eps=sample_eps,
91
- train_sample_type=train_sample_type,
92
- mean=mean,
93
- std=std,
94
- shift_scale =shift_scale,
95
- )
96
-
97
- return state
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ultrashape/models/diffusion/transport/integrators.py DELETED
@@ -1,142 +0,0 @@
1
- # This file includes code derived from the SiT project (https://github.com/willisma/SiT),
2
- # which is licensed under the MIT License.
3
- #
4
- # MIT License
5
- #
6
- # Copyright (c) Meta Platforms, Inc. and affiliates.
7
- #
8
- # Permission is hereby granted, free of charge, to any person obtaining a copy
9
- # of this software and associated documentation files (the "Software"), to deal
10
- # in the Software without restriction, including without limitation the rights
11
- # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
- # copies of the Software, and to permit persons to whom the Software is
13
- # furnished to do so, subject to the following conditions:
14
- #
15
- # The above copyright notice and this permission notice shall be included in all
16
- # copies or substantial portions of the Software.
17
- #
18
- # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19
- # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20
- # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21
- # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22
- # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23
- # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24
- # SOFTWARE.
25
-
26
- import numpy as np
27
- import torch as th
28
- import torch.nn as nn
29
- from torchdiffeq import odeint
30
- from functools import partial
31
- from tqdm import tqdm
32
-
33
- class sde:
34
- """SDE solver class"""
35
- def __init__(
36
- self,
37
- drift,
38
- diffusion,
39
- *,
40
- t0,
41
- t1,
42
- num_steps,
43
- sampler_type,
44
- ):
45
- assert t0 < t1, "SDE sampler has to be in forward time"
46
-
47
- self.num_timesteps = num_steps
48
- self.t = th.linspace(t0, t1, num_steps)
49
- self.dt = self.t[1] - self.t[0]
50
- self.drift = drift
51
- self.diffusion = diffusion
52
- self.sampler_type = sampler_type
53
-
54
- def __Euler_Maruyama_step(self, x, mean_x, t, model, **model_kwargs):
55
- w_cur = th.randn(x.size()).to(x)
56
- t = th.ones(x.size(0)).to(x) * t
57
- dw = w_cur * th.sqrt(self.dt)
58
- drift = self.drift(x, t, model, **model_kwargs)
59
- diffusion = self.diffusion(x, t)
60
- mean_x = x + drift * self.dt
61
- x = mean_x + th.sqrt(2 * diffusion) * dw
62
- return x, mean_x
63
-
64
- def __Heun_step(self, x, _, t, model, **model_kwargs):
65
- w_cur = th.randn(x.size()).to(x)
66
- dw = w_cur * th.sqrt(self.dt)
67
- t_cur = th.ones(x.size(0)).to(x) * t
68
- diffusion = self.diffusion(x, t_cur)
69
- xhat = x + th.sqrt(2 * diffusion) * dw
70
- K1 = self.drift(xhat, t_cur, model, **model_kwargs)
71
- xp = xhat + self.dt * K1
72
- K2 = self.drift(xp, t_cur + self.dt, model, **model_kwargs)
73
- return xhat + 0.5 * self.dt * (K1 + K2), xhat # at last time point we do not perform the heun step
74
-
75
- def __forward_fn(self):
76
- """TODO: generalize here by adding all private functions ending with steps to it"""
77
- sampler_dict = {
78
- "Euler": self.__Euler_Maruyama_step,
79
- "Heun": self.__Heun_step,
80
- }
81
-
82
- try:
83
- sampler = sampler_dict[self.sampler_type]
84
- except:
85
- raise NotImplementedError("Smapler type not implemented.")
86
-
87
- return sampler
88
-
89
- def sample(self, init, model, **model_kwargs):
90
- """forward loop of sde"""
91
- x = init
92
- mean_x = init
93
- samples = []
94
- sampler = self.__forward_fn()
95
- for ti in self.t[:-1]:
96
- with th.no_grad():
97
- x, mean_x = sampler(x, mean_x, ti, model, **model_kwargs)
98
- samples.append(x)
99
-
100
- return samples
101
-
102
- class ode:
103
- """ODE solver class"""
104
- def __init__(
105
- self,
106
- drift,
107
- *,
108
- t0,
109
- t1,
110
- sampler_type,
111
- num_steps,
112
- atol,
113
- rtol,
114
- ):
115
- assert t0 < t1, "ODE sampler has to be in forward time"
116
-
117
- self.drift = drift
118
- self.t = th.linspace(t0, t1, num_steps)
119
- self.atol = atol
120
- self.rtol = rtol
121
- self.sampler_type = sampler_type
122
-
123
- def sample(self, x, model, **model_kwargs):
124
-
125
- device = x[0].device if isinstance(x, tuple) else x.device
126
- def _fn(t, x):
127
- t = th.ones(x[0].size(0)).to(device) * t if isinstance(x, tuple) else th.ones(x.size(0)).to(device) * t
128
- model_output = self.drift(x, t, model, **model_kwargs)
129
- return model_output
130
-
131
- t = self.t.to(device)
132
- atol = [self.atol] * len(x) if isinstance(x, tuple) else [self.atol]
133
- rtol = [self.rtol] * len(x) if isinstance(x, tuple) else [self.rtol]
134
- samples = odeint(
135
- _fn,
136
- x,
137
- t,
138
- method=self.sampler_type,
139
- atol=atol,
140
- rtol=rtol
141
- )
142
- return samples
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ultrashape/models/diffusion/transport/path.py DELETED
@@ -1,220 +0,0 @@
1
- # This file includes code derived from the SiT project (https://github.com/willisma/SiT),
2
- # which is licensed under the MIT License.
3
- #
4
- # MIT License
5
- #
6
- # Copyright (c) Meta Platforms, Inc. and affiliates.
7
- #
8
- # Permission is hereby granted, free of charge, to any person obtaining a copy
9
- # of this software and associated documentation files (the "Software"), to deal
10
- # in the Software without restriction, including without limitation the rights
11
- # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
- # copies of the Software, and to permit persons to whom the Software is
13
- # furnished to do so, subject to the following conditions:
14
- #
15
- # The above copyright notice and this permission notice shall be included in all
16
- # copies or substantial portions of the Software.
17
- #
18
- # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19
- # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20
- # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21
- # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22
- # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23
- # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24
- # SOFTWARE.
25
-
26
- import torch as th
27
- import numpy as np
28
- from functools import partial
29
-
30
- def expand_t_like_x(t, x):
31
- """Function to reshape time t to broadcastable dimension of x
32
- Args:
33
- t: [batch_dim,], time vector
34
- x: [batch_dim,...], data point
35
- """
36
- dims = [1] * (len(x.size()) - 1)
37
- t = t.view(t.size(0), *dims)
38
- return t
39
-
40
-
41
- #################### Coupling Plans ####################
42
-
43
- class ICPlan:
44
- """Linear Coupling Plan"""
45
- def __init__(self, sigma=0.0):
46
- self.sigma = sigma
47
-
48
- def compute_alpha_t(self, t):
49
- """Compute the data coefficient along the path"""
50
- return t, 1
51
-
52
- def compute_sigma_t(self, t):
53
- """Compute the noise coefficient along the path"""
54
- return 1 - t, -1
55
-
56
- def compute_d_alpha_alpha_ratio_t(self, t):
57
- """Compute the ratio between d_alpha and alpha"""
58
- return 1 / t
59
-
60
- def compute_drift(self, x, t):
61
- """We always output sde according to score parametrization; """
62
- t = expand_t_like_x(t, x)
63
- alpha_ratio = self.compute_d_alpha_alpha_ratio_t(t)
64
- sigma_t, d_sigma_t = self.compute_sigma_t(t)
65
- drift = alpha_ratio * x
66
- diffusion = alpha_ratio * (sigma_t ** 2) - sigma_t * d_sigma_t
67
-
68
- return -drift, diffusion
69
-
70
- def compute_diffusion(self, x, t, form="constant", norm=1.0):
71
- """Compute the diffusion term of the SDE
72
- Args:
73
- x: [batch_dim, ...], data point
74
- t: [batch_dim,], time vector
75
- form: str, form of the diffusion term
76
- norm: float, norm of the diffusion term
77
- """
78
- t = expand_t_like_x(t, x)
79
- choices = {
80
- "constant": norm,
81
- "SBDM": norm * self.compute_drift(x, t)[1],
82
- "sigma": norm * self.compute_sigma_t(t)[0],
83
- "linear": norm * (1 - t),
84
- "decreasing": 0.25 * (norm * th.cos(np.pi * t) + 1) ** 2,
85
- "inccreasing-decreasing": norm * th.sin(np.pi * t) ** 2,
86
- }
87
-
88
- try:
89
- diffusion = choices[form]
90
- except KeyError:
91
- raise NotImplementedError(f"Diffusion form {form} not implemented")
92
-
93
- return diffusion
94
-
95
- def get_score_from_velocity(self, velocity, x, t):
96
- """Wrapper function: transfrom velocity prediction model to score
97
- Args:
98
- velocity: [batch_dim, ...] shaped tensor; velocity model output
99
- x: [batch_dim, ...] shaped tensor; x_t data point
100
- t: [batch_dim,] time tensor
101
- """
102
- t = expand_t_like_x(t, x)
103
- alpha_t, d_alpha_t = self.compute_alpha_t(t)
104
- sigma_t, d_sigma_t = self.compute_sigma_t(t)
105
- mean = x
106
- reverse_alpha_ratio = alpha_t / d_alpha_t
107
- var = sigma_t**2 - reverse_alpha_ratio * d_sigma_t * sigma_t
108
- score = (reverse_alpha_ratio * velocity - mean) / var
109
- return score
110
-
111
- def get_noise_from_velocity(self, velocity, x, t):
112
- """Wrapper function: transfrom velocity prediction model to denoiser
113
- Args:
114
- velocity: [batch_dim, ...] shaped tensor; velocity model output
115
- x: [batch_dim, ...] shaped tensor; x_t data point
116
- t: [batch_dim,] time tensor
117
- """
118
- t = expand_t_like_x(t, x)
119
- alpha_t, d_alpha_t = self.compute_alpha_t(t)
120
- sigma_t, d_sigma_t = self.compute_sigma_t(t)
121
- mean = x
122
- reverse_alpha_ratio = alpha_t / d_alpha_t
123
- var = reverse_alpha_ratio * d_sigma_t - sigma_t
124
- noise = (reverse_alpha_ratio * velocity - mean) / var
125
- return noise
126
-
127
- def get_velocity_from_score(self, score, x, t):
128
- """Wrapper function: transfrom score prediction model to velocity
129
- Args:
130
- score: [batch_dim, ...] shaped tensor; score model output
131
- x: [batch_dim, ...] shaped tensor; x_t data point
132
- t: [batch_dim,] time tensor
133
- """
134
- t = expand_t_like_x(t, x)
135
- drift, var = self.compute_drift(x, t)
136
- velocity = var * score - drift
137
- return velocity
138
-
139
- def compute_mu_t(self, t, x0, x1):
140
- """Compute the mean of time-dependent density p_t"""
141
- t = expand_t_like_x(t, x1)
142
- alpha_t, _ = self.compute_alpha_t(t)
143
- sigma_t, _ = self.compute_sigma_t(t)
144
- # t*x1 + (1-t)*x0 ; t=0 x0; t=1 x1
145
- return alpha_t * x1 + sigma_t * x0
146
-
147
- def compute_xt(self, t, x0, x1):
148
- """Sample xt from time-dependent density p_t; rng is required"""
149
- xt = self.compute_mu_t(t, x0, x1)
150
- return xt
151
-
152
- def compute_ut(self, t, x0, x1, xt):
153
- """Compute the vector field corresponding to p_t"""
154
- t = expand_t_like_x(t, x1)
155
- _, d_alpha_t = self.compute_alpha_t(t)
156
- _, d_sigma_t = self.compute_sigma_t(t)
157
- return d_alpha_t * x1 + d_sigma_t * x0
158
-
159
- def plan(self, t, x0, x1):
160
- xt = self.compute_xt(t, x0, x1)
161
- ut = self.compute_ut(t, x0, x1, xt)
162
- return t, xt, ut
163
-
164
-
165
- class VPCPlan(ICPlan):
166
- """class for VP path flow matching"""
167
-
168
- def __init__(self, sigma_min=0.1, sigma_max=20.0):
169
- self.sigma_min = sigma_min
170
- self.sigma_max = sigma_max
171
- self.log_mean_coeff = lambda t: -0.25 * ((1 - t) ** 2) * \
172
- (self.sigma_max - self.sigma_min) - 0.5 * (1 - t) * self.sigma_min
173
- self.d_log_mean_coeff = lambda t: 0.5 * (1 - t) * \
174
- (self.sigma_max - self.sigma_min) + 0.5 * self.sigma_min
175
-
176
-
177
- def compute_alpha_t(self, t):
178
- """Compute coefficient of x1"""
179
- alpha_t = self.log_mean_coeff(t)
180
- alpha_t = th.exp(alpha_t)
181
- d_alpha_t = alpha_t * self.d_log_mean_coeff(t)
182
- return alpha_t, d_alpha_t
183
-
184
- def compute_sigma_t(self, t):
185
- """Compute coefficient of x0"""
186
- p_sigma_t = 2 * self.log_mean_coeff(t)
187
- sigma_t = th.sqrt(1 - th.exp(p_sigma_t))
188
- d_sigma_t = th.exp(p_sigma_t) * (2 * self.d_log_mean_coeff(t)) / (-2 * sigma_t)
189
- return sigma_t, d_sigma_t
190
-
191
- def compute_d_alpha_alpha_ratio_t(self, t):
192
- """Special purposed function for computing numerical stabled d_alpha_t / alpha_t"""
193
- return self.d_log_mean_coeff(t)
194
-
195
- def compute_drift(self, x, t):
196
- """Compute the drift term of the SDE"""
197
- t = expand_t_like_x(t, x)
198
- beta_t = self.sigma_min + (1 - t) * (self.sigma_max - self.sigma_min)
199
- return -0.5 * beta_t * x, beta_t / 2
200
-
201
-
202
- class GVPCPlan(ICPlan):
203
- def __init__(self, sigma=0.0):
204
- super().__init__(sigma)
205
-
206
- def compute_alpha_t(self, t):
207
- """Compute coefficient of x1"""
208
- alpha_t = th.sin(t * np.pi / 2)
209
- d_alpha_t = np.pi / 2 * th.cos(t * np.pi / 2)
210
- return alpha_t, d_alpha_t
211
-
212
- def compute_sigma_t(self, t):
213
- """Compute coefficient of x0"""
214
- sigma_t = th.cos(t * np.pi / 2)
215
- d_sigma_t = -np.pi / 2 * th.sin(t * np.pi / 2)
216
- return sigma_t, d_sigma_t
217
-
218
- def compute_d_alpha_alpha_ratio_t(self, t):
219
- """Special purposed function for computing numerical stabled d_alpha_t / alpha_t"""
220
- return np.pi / (2 * th.tan(t * np.pi / 2))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ultrashape/models/diffusion/transport/transport.py DELETED
@@ -1,534 +0,0 @@
1
- # This file includes code derived from the SiT project (https://github.com/willisma/SiT),
2
- # which is licensed under the MIT License.
3
- #
4
- # MIT License
5
- #
6
- # Copyright (c) Meta Platforms, Inc. and affiliates.
7
- #
8
- # Permission is hereby granted, free of charge, to any person obtaining a copy
9
- # of this software and associated documentation files (the "Software"), to deal
10
- # in the Software without restriction, including without limitation the rights
11
- # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
- # copies of the Software, and to permit persons to whom the Software is
13
- # furnished to do so, subject to the following conditions:
14
- #
15
- # The above copyright notice and this permission notice shall be included in all
16
- # copies or substantial portions of the Software.
17
- #
18
- # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19
- # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20
- # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21
- # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22
- # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23
- # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24
- # SOFTWARE.
25
-
26
- import torch as th
27
- import numpy as np
28
- import logging
29
-
30
- import enum
31
-
32
- from . import path
33
- from .utils import EasyDict, log_state, mean_flat
34
- from .integrators import ode, sde
35
-
36
-
37
- class ModelType(enum.Enum):
38
- """
39
- Which type of output the model predicts.
40
- """
41
-
42
- NOISE = enum.auto() # the model predicts epsilon
43
- SCORE = enum.auto() # the model predicts \nabla \log p(x)
44
- VELOCITY = enum.auto() # the model predicts v(x)
45
-
46
-
47
- class PathType(enum.Enum):
48
- """
49
- Which type of path to use.
50
- """
51
-
52
- LINEAR = enum.auto()
53
- GVP = enum.auto()
54
- VP = enum.auto()
55
-
56
-
57
- class WeightType(enum.Enum):
58
- """
59
- Which type of weighting to use.
60
- """
61
-
62
- NONE = enum.auto()
63
- VELOCITY = enum.auto()
64
- LIKELIHOOD = enum.auto()
65
-
66
-
67
- class Transport:
68
-
69
- def __init__(
70
- self,
71
- *,
72
- model_type,
73
- path_type,
74
- loss_type,
75
- train_eps,
76
- sample_eps,
77
- train_sample_type = "uniform",
78
- **kwargs,
79
- ):
80
- path_options = {
81
- PathType.LINEAR: path.ICPlan,
82
- PathType.GVP: path.GVPCPlan,
83
- PathType.VP: path.VPCPlan,
84
- }
85
-
86
- self.loss_type = loss_type
87
- self.model_type = model_type
88
- self.path_sampler = path_options[path_type]()
89
- self.train_eps = train_eps
90
- self.sample_eps = sample_eps
91
- self.train_sample_type = train_sample_type
92
- if self.train_sample_type == "logit_normal":
93
- self.mean = kwargs['mean']
94
- self.std = kwargs['std']
95
- self.shift_scale = kwargs['shift_scale']
96
- print(f"using logit normal sample, shift scale is {self.shift_scale}")
97
-
98
- def prior_logp(self, z):
99
- '''
100
- Standard multivariate normal prior
101
- Assume z is batched
102
- '''
103
- shape = th.tensor(z.size())
104
- N = th.prod(shape[1:])
105
- _fn = lambda x: -N / 2. * np.log(2 * np.pi) - th.sum(x ** 2) / 2.
106
- return th.vmap(_fn)(z)
107
-
108
- def check_interval(
109
- self,
110
- train_eps,
111
- sample_eps,
112
- *,
113
- diffusion_form="SBDM",
114
- sde=False,
115
- reverse=False,
116
- eval=False,
117
- last_step_size=0.0,
118
- ):
119
- t0 = 0
120
- t1 = 1
121
- eps = train_eps if not eval else sample_eps
122
- if (type(self.path_sampler) in [path.VPCPlan]):
123
-
124
- t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size
125
-
126
- elif (type(self.path_sampler) in [path.ICPlan, path.GVPCPlan]) \
127
- and (
128
- self.model_type != ModelType.VELOCITY or sde): # avoid numerical issue by taking a first semi-implicit step
129
-
130
- t0 = eps if (diffusion_form == "SBDM" and sde) or self.model_type != ModelType.VELOCITY else 0
131
- t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size
132
-
133
- if reverse:
134
- t0, t1 = 1 - t0, 1 - t1
135
-
136
- return t0, t1
137
-
138
- def sample(self, x1):
139
- """Sampling x0 & t based on shape of x1 (if needed)
140
- Args:
141
- x1 - data point; [batch, *dim]
142
- """
143
-
144
- x0 = th.randn_like(x1)
145
- if self.train_sample_type=="uniform":
146
- t0, t1 = self.check_interval(self.train_eps, self.sample_eps)
147
- t = th.rand((x1.shape[0],)) * (t1 - t0) + t0
148
- t = t.to(x1)
149
- elif self.train_sample_type=="logit_normal":
150
- t = th.randn((x1.shape[0],)) * self.std + self.mean
151
- t = t.to(x1)
152
- t = 1/(1+th.exp(-t))
153
-
154
- t = np.sqrt(self.shift_scale)*t/(1+(np.sqrt(self.shift_scale)-1)*t)
155
-
156
- return t, x0, x1
157
-
158
- def training_losses(
159
- self,
160
- model,
161
- x1,
162
- model_kwargs=None
163
- ):
164
- """Loss for training the score model
165
- Args:
166
- - model: backbone model; could be score, noise, or velocity
167
- - x1: datapoint
168
- - model_kwargs: additional arguments for the model
169
- """
170
- if model_kwargs == None:
171
- model_kwargs = {}
172
-
173
- t, x0, x1 = self.sample(x1)
174
- t, xt, ut = self.path_sampler.plan(t, x0, x1)
175
- model_output = model(xt, t, **model_kwargs)
176
- B, *_, C = xt.shape
177
- assert model_output.size() == (B, *xt.size()[1:-1], C)
178
-
179
- terms = {}
180
- terms['pred'] = model_output
181
- if self.model_type == ModelType.VELOCITY:
182
- terms['loss'] = mean_flat(((model_output - ut) ** 2))
183
- else:
184
- _, drift_var = self.path_sampler.compute_drift(xt, t)
185
- sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, xt))
186
- if self.loss_type in [WeightType.VELOCITY]:
187
- weight = (drift_var / sigma_t) ** 2
188
- elif self.loss_type in [WeightType.LIKELIHOOD]:
189
- weight = drift_var / (sigma_t ** 2)
190
- elif self.loss_type in [WeightType.NONE]:
191
- weight = 1
192
- else:
193
- raise NotImplementedError()
194
-
195
- if self.model_type == ModelType.NOISE:
196
- terms['loss'] = mean_flat(weight * ((model_output - x0) ** 2))
197
- else:
198
- terms['loss'] = mean_flat(weight * ((model_output * sigma_t + x0) ** 2))
199
-
200
- return terms
201
-
202
- def get_drift(
203
- self
204
- ):
205
- """member function for obtaining the drift of the probability flow ODE"""
206
-
207
- def score_ode(x, t, model, **model_kwargs):
208
- drift_mean, drift_var = self.path_sampler.compute_drift(x, t)
209
- model_output = model(x, t, **model_kwargs)
210
- return (-drift_mean + drift_var * model_output) # by change of variable
211
-
212
- def noise_ode(x, t, model, **model_kwargs):
213
- drift_mean, drift_var = self.path_sampler.compute_drift(x, t)
214
- sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))
215
- model_output = model(x, t, **model_kwargs)
216
- score = model_output / -sigma_t
217
- return (-drift_mean + drift_var * score)
218
-
219
- def velocity_ode(x, t, model, **model_kwargs):
220
- model_output = model(x, t, **model_kwargs)
221
- return model_output
222
-
223
- if self.model_type == ModelType.NOISE:
224
- drift_fn = noise_ode
225
- elif self.model_type == ModelType.SCORE:
226
- drift_fn = score_ode
227
- else:
228
- drift_fn = velocity_ode
229
-
230
- def body_fn(x, t, model, **model_kwargs):
231
- model_output = drift_fn(x, t, model, **model_kwargs)
232
- assert model_output.shape == x.shape, "Output shape from ODE solver must match input shape"
233
- return model_output
234
-
235
- return body_fn
236
-
237
- def get_score(
238
- self,
239
- ):
240
- """member function for obtaining score of
241
- x_t = alpha_t * x + sigma_t * eps"""
242
- if self.model_type == ModelType.NOISE:
243
- score_fn = lambda x, t, model, **kwargs: model(x, t, **kwargs) / - \
244
- self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))[0]
245
- elif self.model_type == ModelType.SCORE:
246
- score_fn = lambda x, t, model, **kwagrs: model(x, t, **kwagrs)
247
- elif self.model_type == ModelType.VELOCITY:
248
- score_fn = lambda x, t, model, **kwargs: self.path_sampler.get_score_from_velocity(model(x, t, **kwargs), x,
249
- t)
250
- else:
251
- raise NotImplementedError()
252
-
253
- return score_fn
254
-
255
-
256
- class Sampler:
257
- """Sampler class for the transport model"""
258
-
259
- def __init__(
260
- self,
261
- transport,
262
- ):
263
- """Constructor for a general sampler; supporting different sampling methods
264
- Args:
265
- - transport: an tranport object specify model prediction & interpolant type
266
- """
267
-
268
- self.transport = transport
269
- self.drift = self.transport.get_drift()
270
- self.score = self.transport.get_score()
271
-
272
- def __get_sde_diffusion_and_drift(
273
- self,
274
- *,
275
- diffusion_form="SBDM",
276
- diffusion_norm=1.0,
277
- ):
278
-
279
- def diffusion_fn(x, t):
280
- diffusion = self.transport.path_sampler.compute_diffusion(x, t, form=diffusion_form, norm=diffusion_norm)
281
- return diffusion
282
-
283
- sde_drift = \
284
- lambda x, t, model, **kwargs: \
285
- self.drift(x, t, model, **kwargs) + diffusion_fn(x, t) * self.score(x, t, model, **kwargs)
286
-
287
- sde_diffusion = diffusion_fn
288
-
289
- return sde_drift, sde_diffusion
290
-
291
- def __get_last_step(
292
- self,
293
- sde_drift,
294
- *,
295
- last_step,
296
- last_step_size,
297
- ):
298
- """Get the last step function of the SDE solver"""
299
-
300
- if last_step is None:
301
- last_step_fn = \
302
- lambda x, t, model, **model_kwargs: \
303
- x
304
- elif last_step == "Mean":
305
- last_step_fn = \
306
- lambda x, t, model, **model_kwargs: \
307
- x + sde_drift(x, t, model, **model_kwargs) * last_step_size
308
- elif last_step == "Tweedie":
309
- alpha = self.transport.path_sampler.compute_alpha_t # simple aliasing; the original name was too long
310
- sigma = self.transport.path_sampler.compute_sigma_t
311
- last_step_fn = \
312
- lambda x, t, model, **model_kwargs: \
313
- x / alpha(t)[0][0] + (sigma(t)[0][0] ** 2) / alpha(t)[0][0] * self.score(x, t, model,
314
- **model_kwargs)
315
- elif last_step == "Euler":
316
- last_step_fn = \
317
- lambda x, t, model, **model_kwargs: \
318
- x + self.drift(x, t, model, **model_kwargs) * last_step_size
319
- else:
320
- raise NotImplementedError()
321
-
322
- return last_step_fn
323
-
324
- def sample_sde(
325
- self,
326
- *,
327
- sampling_method="Euler",
328
- diffusion_form="SBDM",
329
- diffusion_norm=1.0,
330
- last_step="Mean",
331
- last_step_size=0.04,
332
- num_steps=250,
333
- ):
334
- """returns a sampling function with given SDE settings
335
- Args:
336
- - sampling_method: type of sampler used in solving the SDE; default to be Euler-Maruyama
337
- - diffusion_form: function form of diffusion coefficient; default to be matching SBDM
338
- - diffusion_norm: function magnitude of diffusion coefficient; default to 1
339
- - last_step: type of the last step; default to identity
340
- - last_step_size: size of the last step; default to match the stride of 250 steps over [0,1]
341
- - num_steps: total integration step of SDE
342
- """
343
-
344
- if last_step is None:
345
- last_step_size = 0.0
346
-
347
- sde_drift, sde_diffusion = self.__get_sde_diffusion_and_drift(
348
- diffusion_form=diffusion_form,
349
- diffusion_norm=diffusion_norm,
350
- )
351
-
352
- t0, t1 = self.transport.check_interval(
353
- self.transport.train_eps,
354
- self.transport.sample_eps,
355
- diffusion_form=diffusion_form,
356
- sde=True,
357
- eval=True,
358
- reverse=False,
359
- last_step_size=last_step_size,
360
- )
361
-
362
- _sde = sde(
363
- sde_drift,
364
- sde_diffusion,
365
- t0=t0,
366
- t1=t1,
367
- num_steps=num_steps,
368
- sampler_type=sampling_method
369
- )
370
-
371
- last_step_fn = self.__get_last_step(sde_drift, last_step=last_step, last_step_size=last_step_size)
372
-
373
- def _sample(init, model, **model_kwargs):
374
- xs = _sde.sample(init, model, **model_kwargs)
375
- ts = th.ones(init.size(0), device=init.device) * t1
376
- x = last_step_fn(xs[-1], ts, model, **model_kwargs)
377
- xs.append(x)
378
-
379
- assert len(xs) == num_steps, "Samples does not match the number of steps"
380
-
381
- return xs
382
-
383
- return _sample
384
-
385
- def sample_ode(
386
- self,
387
- *,
388
- sampling_method="dopri5",
389
- num_steps=50,
390
- atol=1e-6,
391
- rtol=1e-3,
392
- reverse=False,
393
- ):
394
- """returns a sampling function with given ODE settings
395
- Args:
396
- - sampling_method: type of sampler used in solving the ODE; default to be Dopri5
397
- - num_steps:
398
- - fixed solver (Euler, Heun): the actual number of integration steps performed
399
- - adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation
400
- - atol: absolute error tolerance for the solver
401
- - rtol: relative error tolerance for the solver
402
- - reverse: whether solving the ODE in reverse (data to noise); default to False
403
- """
404
- if reverse:
405
- drift = lambda x, t, model, **kwargs: self.drift(x, th.ones_like(t) * (1 - t), model, **kwargs)
406
- else:
407
- drift = self.drift
408
-
409
- t0, t1 = self.transport.check_interval(
410
- self.transport.train_eps,
411
- self.transport.sample_eps,
412
- sde=False,
413
- eval=True,
414
- reverse=reverse,
415
- last_step_size=0.0,
416
- )
417
-
418
- _ode = ode(
419
- drift=drift,
420
- t0=t0,
421
- t1=t1,
422
- sampler_type=sampling_method,
423
- num_steps=num_steps,
424
- atol=atol,
425
- rtol=rtol,
426
- )
427
-
428
- return _ode.sample
429
-
430
- def sample_ode_intermediate(
431
- self,
432
- *,
433
- sampling_method="dopri5",
434
- num_steps=50,
435
- atol=1e-6,
436
- rtol=1e-3,
437
- t=0.5,
438
- reverse=False,
439
- ):
440
- """returns a sampling function with given ODE settings
441
- Args:
442
- - sampling_method: type of sampler used in solving the ODE; default to be Dopri5
443
- - num_steps:
444
- - fixed solver (Euler, Heun): the actual number of integration steps performed
445
- - adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation
446
- - atol: absolute error tolerance for the solver
447
- - rtol: relative error tolerance for the solver
448
- - reverse: whether solving the ODE in reverse (data to noise); default to False
449
- """
450
- if reverse:
451
- drift = lambda x, t, model, **kwargs: self.drift(x, th.ones_like(t) * (1 - t), model, **kwargs)
452
- else:
453
- drift = self.drift
454
-
455
- t0, t1 = self.transport.check_interval(
456
- self.transport.train_eps,
457
- self.transport.sample_eps,
458
- sde=False,
459
- eval=True,
460
- reverse=reverse,
461
- last_step_size=0.0,
462
- )
463
-
464
- _ode = ode(
465
- drift=drift,
466
- t0=t,
467
- t1=t1,
468
- sampler_type=sampling_method,
469
- num_steps=num_steps,
470
- atol=atol,
471
- rtol=rtol,
472
- )
473
-
474
- return _ode.sample
475
-
476
- def sample_ode_likelihood(
477
- self,
478
- *,
479
- sampling_method="dopri5",
480
- num_steps=50,
481
- atol=1e-6,
482
- rtol=1e-3,
483
- ):
484
-
485
- """returns a sampling function for calculating likelihood with given ODE settings
486
- Args:
487
- - sampling_method: type of sampler used in solving the ODE; default to be Dopri5
488
- - num_steps:
489
- - fixed solver (Euler, Heun): the actual number of integration steps performed
490
- - adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation
491
- - atol: absolute error tolerance for the solver
492
- - rtol: relative error tolerance for the solver
493
- """
494
-
495
- def _likelihood_drift(x, t, model, **model_kwargs):
496
- x, _ = x
497
- eps = th.randint(2, x.size(), dtype=th.float, device=x.device) * 2 - 1
498
- t = th.ones_like(t) * (1 - t)
499
- with th.enable_grad():
500
- x.requires_grad = True
501
- grad = th.autograd.grad(th.sum(self.drift(x, t, model, **model_kwargs) * eps), x)[0]
502
- logp_grad = th.sum(grad * eps, dim=tuple(range(1, len(x.size()))))
503
- drift = self.drift(x, t, model, **model_kwargs)
504
- return (-drift, logp_grad)
505
-
506
- t0, t1 = self.transport.check_interval(
507
- self.transport.train_eps,
508
- self.transport.sample_eps,
509
- sde=False,
510
- eval=True,
511
- reverse=False,
512
- last_step_size=0.0,
513
- )
514
-
515
- _ode = ode(
516
- drift=_likelihood_drift,
517
- t0=t0,
518
- t1=t1,
519
- sampler_type=sampling_method,
520
- num_steps=num_steps,
521
- atol=atol,
522
- rtol=rtol,
523
- )
524
-
525
- def _sample_fn(x, model, **model_kwargs):
526
- init_logp = th.zeros(x.size(0)).to(x)
527
- input = (x, init_logp)
528
- drift, delta_logp = _ode.sample(input, model, **model_kwargs)
529
- drift, delta_logp = drift[-1], delta_logp[-1]
530
- prior_logp = self.transport.prior_logp(drift)
531
- logp = prior_logp - delta_logp
532
- return logp, drift
533
-
534
- return _sample_fn
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ultrashape/models/diffusion/transport/utils.py DELETED
@@ -1,54 +0,0 @@
1
- # This file includes code derived from the SiT project (https://github.com/willisma/SiT),
2
- # which is licensed under the MIT License.
3
- #
4
- # MIT License
5
- #
6
- # Copyright (c) Meta Platforms, Inc. and affiliates.
7
- #
8
- # Permission is hereby granted, free of charge, to any person obtaining a copy
9
- # of this software and associated documentation files (the "Software"), to deal
10
- # in the Software without restriction, including without limitation the rights
11
- # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
- # copies of the Software, and to permit persons to whom the Software is
13
- # furnished to do so, subject to the following conditions:
14
- #
15
- # The above copyright notice and this permission notice shall be included in all
16
- # copies or substantial portions of the Software.
17
- #
18
- # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19
- # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20
- # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21
- # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22
- # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23
- # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24
- # SOFTWARE.
25
-
26
- import torch as th
27
-
28
- class EasyDict:
29
-
30
- def __init__(self, sub_dict):
31
- for k, v in sub_dict.items():
32
- setattr(self, k, v)
33
-
34
- def __getitem__(self, key):
35
- return getattr(self, key)
36
-
37
- def mean_flat(x):
38
- """
39
- Take the mean over all non-batch dimensions.
40
- """
41
- return th.mean(x, dim=list(range(1, len(x.size()))))
42
-
43
- def log_state(state):
44
- result = []
45
-
46
- sorted_state = dict(sorted(state.items()))
47
- for key, value in sorted_state.items():
48
- # Check if the value is an instance of a class
49
- if "<object" in str(value) or "object at" in str(value):
50
- result.append(f"{key}: [{value.__class__.__name__}]")
51
- else:
52
- result.append(f"{key}: {value}")
53
-
54
- return '\n'.join(result)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ultrashape/pipelines.py DELETED
@@ -1,797 +0,0 @@
1
- # ==============================================================================
2
- # Original work Copyright (c) 2025 Tencent.
3
- # Modified work Copyright (c) 2025 UltraShape Team.
4
- #
5
- # Modified by UltraShape on 2025.12.25
6
- # ==============================================================================
7
-
8
- # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
9
- # except for the third-party components listed below.
10
- # Hunyuan 3D does not impose any additional limitations beyond what is outlined
11
- # in the repsective licenses of these third-party components.
12
- # Users must comply with all terms and conditions of original licenses of these third-party
13
- # components and must ensure that the usage of the third party components adheres to
14
- # all relevant laws and regulations.
15
-
16
- # For avoidance of doubts, Hunyuan 3D means the large language models and
17
- # their software and algorithms, including trained model weights, parameters (including
18
- # optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
19
- # fine-tuning enabling code and other elements of the foregoing made publicly available
20
- # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
21
-
22
- import copy
23
- import importlib
24
- import inspect
25
- import os
26
- from typing import List, Optional, Union
27
-
28
- import numpy as np
29
- import torch
30
- import trimesh
31
- import yaml
32
- from PIL import Image
33
- from diffusers.utils.torch_utils import randn_tensor
34
- from diffusers.utils.import_utils import is_accelerate_version, is_accelerate_available
35
- from tqdm import tqdm
36
-
37
- from .models.autoencoders import ShapeVAE
38
- from .models.autoencoders import SurfaceExtractors
39
- from .utils import logger, synchronize_timer, smart_load_model
40
-
41
-
42
- def retrieve_timesteps(
43
- scheduler,
44
- num_inference_steps: Optional[int] = None,
45
- device: Optional[Union[str, torch.device]] = None,
46
- timesteps: Optional[List[int]] = None,
47
- sigmas: Optional[List[float]] = None,
48
- **kwargs,
49
- ):
50
- """
51
- Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
52
- custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
53
-
54
- Args:
55
- scheduler (`SchedulerMixin`):
56
- The scheduler to get timesteps from.
57
- num_inference_steps (`int`):
58
- The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
59
- must be `None`.
60
- device (`str` or `torch.device`, *optional*):
61
- The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
62
- timesteps (`List[int]`, *optional*):
63
- Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
64
- `num_inference_steps` and `sigmas` must be `None`.
65
- sigmas (`List[float]`, *optional*):
66
- Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
67
- `num_inference_steps` and `timesteps` must be `None`.
68
-
69
- Returns:
70
- `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
71
- second element is the number of inference steps.
72
- """
73
- if timesteps is not None and sigmas is not None:
74
- raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
75
- if timesteps is not None:
76
- accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
77
- if not accepts_timesteps:
78
- raise ValueError(
79
- f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
80
- f" timestep schedules. Please check whether you are using the correct scheduler."
81
- )
82
- scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
83
- timesteps = scheduler.timesteps
84
- num_inference_steps = len(timesteps)
85
- elif sigmas is not None:
86
- accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
87
- if not accept_sigmas:
88
- raise ValueError(
89
- f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
90
- f" sigmas schedules. Please check whether you are using the correct scheduler."
91
- )
92
- scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
93
- timesteps = scheduler.timesteps
94
- num_inference_steps = len(timesteps)
95
- else:
96
- scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
97
- timesteps = scheduler.timesteps
98
- return timesteps, num_inference_steps
99
-
100
-
101
- @synchronize_timer('Export to trimesh')
102
- def export_to_trimesh(mesh_output):
103
- if isinstance(mesh_output, list):
104
- outputs = []
105
- for mesh in mesh_output:
106
- if mesh is None:
107
- outputs.append(None)
108
- else:
109
- mesh.mesh_f = mesh.mesh_f[:, ::-1]
110
- mesh_output = trimesh.Trimesh(mesh.mesh_v, mesh.mesh_f)
111
- outputs.append(mesh_output)
112
- return outputs
113
- else:
114
- mesh_output.mesh_f = mesh_output.mesh_f[:, ::-1]
115
- mesh_output = trimesh.Trimesh(mesh_output.mesh_v, mesh_output.mesh_f)
116
- return mesh_output
117
-
118
-
119
- def get_obj_from_str(string, reload=False):
120
- module, cls = string.rsplit(".", 1)
121
- if reload:
122
- module_imp = importlib.import_module(module)
123
- importlib.reload(module_imp)
124
- return getattr(importlib.import_module(module, package=None), cls)
125
-
126
-
127
- def instantiate_from_config(config, **kwargs):
128
- if "target" not in config:
129
- raise KeyError("Expected key `target` to instantiate.")
130
- cls = get_obj_from_str(config["target"])
131
- params = config.get("params", dict())
132
- kwargs.update(params)
133
- instance = cls(**kwargs)
134
- return instance
135
-
136
-
137
- class DiTPipeline:
138
- model_cpu_offload_seq = "conditioner->model->vae"
139
- _exclude_from_cpu_offload = []
140
-
141
- @classmethod
142
- @synchronize_timer('DiTPipeline Model Loading')
143
- def from_single_file(
144
- cls,
145
- ckpt_path,
146
- config_path,
147
- device='cuda',
148
- dtype=torch.float16,
149
- use_safetensors=None,
150
- **kwargs,
151
- ):
152
- # load config
153
- with open(config_path, 'r') as f:
154
- config = yaml.safe_load(f)
155
-
156
- # load ckpt
157
- if use_safetensors:
158
- ckpt_path = ckpt_path.replace('.ckpt', '.safetensors')
159
- if not os.path.exists(ckpt_path):
160
- raise FileNotFoundError(f"Model file {ckpt_path} not found")
161
- logger.info(f"Loading model from {ckpt_path}")
162
-
163
- if use_safetensors:
164
- # parse safetensors
165
- import safetensors.torch
166
- safetensors_ckpt = safetensors.torch.load_file(ckpt_path, device='cpu')
167
- ckpt = {}
168
- for key, value in safetensors_ckpt.items():
169
- model_name = key.split('.')[0]
170
- new_key = key[len(model_name) + 1:]
171
- if model_name not in ckpt:
172
- ckpt[model_name] = {}
173
- ckpt[model_name][new_key] = value
174
- else:
175
- ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=True)
176
- # load model
177
- model = instantiate_from_config(config['model'])
178
- model.load_state_dict(ckpt['model'])
179
- vae = instantiate_from_config(config['vae'])
180
- vae.load_state_dict(ckpt['vae'], strict=False)
181
- conditioner = instantiate_from_config(config['conditioner'])
182
- if 'conditioner' in ckpt:
183
- conditioner.load_state_dict(ckpt['conditioner'])
184
- image_processor = instantiate_from_config(config['image_processor'])
185
- scheduler = instantiate_from_config(config['scheduler'])
186
-
187
- model_kwargs = dict(
188
- vae=vae,
189
- model=model,
190
- scheduler=scheduler,
191
- conditioner=conditioner,
192
- image_processor=image_processor,
193
- device=device,
194
- dtype=dtype,
195
- )
196
- model_kwargs.update(kwargs)
197
-
198
- return cls(
199
- **model_kwargs
200
- )
201
-
202
- @classmethod
203
- def from_pretrained(
204
- cls,
205
- model_path,
206
- device='cuda',
207
- dtype=torch.float16,
208
- use_safetensors=False,
209
- variant='fp16',
210
- subfolder='hunyuan3d-dit-v2-1',
211
- **kwargs,
212
- ):
213
- kwargs['from_pretrained_kwargs'] = dict(
214
- model_path=model_path,
215
- subfolder=subfolder,
216
- use_safetensors=use_safetensors,
217
- variant=variant,
218
- dtype=dtype,
219
- device=device,
220
- )
221
- config_path, ckpt_path = smart_load_model(
222
- model_path,
223
- subfolder=subfolder,
224
- use_safetensors=use_safetensors,
225
- variant=variant
226
- )
227
- return cls.from_single_file(
228
- ckpt_path,
229
- config_path,
230
- device=device,
231
- dtype=dtype,
232
- use_safetensors=use_safetensors,
233
- **kwargs
234
- )
235
-
236
- def __init__(
237
- self,
238
- vae,
239
- model,
240
- scheduler,
241
- conditioner,
242
- image_processor,
243
- device='cuda',
244
- dtype=torch.float16,
245
- ref_model=None,
246
- **kwargs
247
- ):
248
- self.vae = vae
249
- self.model = model
250
- self.ref_model = ref_model
251
- self.scheduler = scheduler
252
- self.conditioner = conditioner
253
- self.image_processor = image_processor
254
- self.kwargs = kwargs
255
-
256
- self.components = {
257
- "vae": vae,
258
- "model": model,
259
- "scheduler": scheduler,
260
- "conditioner": conditioner,
261
- "image_processor": image_processor,
262
- }
263
- if ref_model is not None:
264
- self.components["ref_model"] = ref_model
265
-
266
- self.to(device, dtype)
267
-
268
- def compile(self):
269
- self.vae = torch.compile(self.vae)
270
- self.model = torch.compile(self.model)
271
- self.conditioner = torch.compile(self.conditioner)
272
-
273
- def enable_flashvdm(
274
- self,
275
- enabled: bool = True,
276
- adaptive_kv_selection=True,
277
- topk_mode='mean',
278
- mc_algo='mc',
279
- replace_vae=True,
280
- ):
281
- if enabled:
282
- self.vae.enable_flashvdm_decoder(
283
- enabled=enabled,
284
- adaptive_kv_selection=adaptive_kv_selection,
285
- topk_mode=topk_mode,
286
- mc_algo=mc_algo
287
- )
288
- else:
289
- self.vae.enable_flashvdm_decoder(enabled=False)
290
-
291
- def to(self, device=None, dtype=None):
292
- if dtype is not None:
293
- self.dtype = dtype
294
- self.vae.to(dtype=dtype)
295
- self.model.to(dtype=dtype)
296
- self.conditioner.to(dtype=dtype)
297
- if device is not None:
298
- self.device = torch.device(device)
299
- self.vae.to(device)
300
- self.model.to(device)
301
- self.conditioner.to(device)
302
-
303
- @property
304
- def _execution_device(self):
305
- r"""
306
- Returns the device on which the pipeline's models will be executed. After calling
307
- [`~DiffusionPipeline.enable_sequential_cpu_offload`] the execution device can only be inferred from
308
- Accelerate's module hooks.
309
- """
310
- for name, model in self.components.items():
311
- if not isinstance(model, torch.nn.Module) or name in self._exclude_from_cpu_offload:
312
- continue
313
-
314
- if not hasattr(model, "_hf_hook"):
315
- return self.device
316
- for module in model.modules():
317
- if (
318
- hasattr(module, "_hf_hook")
319
- and hasattr(module._hf_hook, "execution_device")
320
- and module._hf_hook.execution_device is not None
321
- ):
322
- return torch.device(module._hf_hook.execution_device)
323
- return self.device
324
-
325
- def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
326
- r"""
327
- Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
328
- to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
329
- method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
330
- `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
331
-
332
- Arguments:
333
- gpu_id (`int`, *optional*):
334
- The ID of the accelerator that shall be used in inference. If not specified, it will default to 0.
335
- device (`torch.Device` or `str`, *optional*, defaults to "cuda"):
336
- The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
337
- default to "cuda".
338
- """
339
- if self.model_cpu_offload_seq is None:
340
- raise ValueError(
341
- "Model CPU offload cannot be enabled because no `model_cpu_offload_seq` class attribute is set."
342
- )
343
-
344
- if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
345
- from accelerate import cpu_offload_with_hook
346
- else:
347
- raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
348
-
349
- torch_device = torch.device(device)
350
- device_index = torch_device.index
351
-
352
- if gpu_id is not None and device_index is not None:
353
- raise ValueError(
354
- f"You have passed both `gpu_id`={gpu_id} and an index as part of the passed device `device`={device}"
355
- f"Cannot pass both. Please make sure to either not define `gpu_id` or not pass the index as part of "
356
- f"the device: `device`={torch_device.type}"
357
- )
358
-
359
- # _offload_gpu_id should be set to passed gpu_id (or id in passed `device`)
360
- # or default to previously set id or default to 0
361
- self._offload_gpu_id = gpu_id or torch_device.index or getattr(self, "_offload_gpu_id", 0)
362
-
363
- device_type = torch_device.type
364
- device = torch.device(f"{device_type}:{self._offload_gpu_id}")
365
-
366
- if self.device.type != "cpu":
367
- self.to("cpu")
368
- device_mod = getattr(torch, self.device.type, None)
369
- if hasattr(device_mod, "empty_cache") and device_mod.is_available():
370
- device_mod.empty_cache()
371
- # otherwise we don't see the memory savings (but they probably exist)
372
-
373
- all_model_components = {k: v for k, v in self.components.items() if isinstance(v, torch.nn.Module)}
374
-
375
- self._all_hooks = []
376
- hook = None
377
- for model_str in self.model_cpu_offload_seq.split("->"):
378
- model = all_model_components.pop(model_str, None)
379
- if not isinstance(model, torch.nn.Module):
380
- continue
381
-
382
- _, hook = cpu_offload_with_hook(model, device, prev_module_hook=hook)
383
- self._all_hooks.append(hook)
384
-
385
- # CPU offload models that are not in the seq chain unless they are explicitly excluded
386
- # these models will stay on CPU until maybe_free_model_hooks is called
387
- # some models cannot be in the seq chain because they are iteratively called,
388
- # such as controlnet
389
- for name, model in all_model_components.items():
390
- if not isinstance(model, torch.nn.Module):
391
- continue
392
-
393
- if name in self._exclude_from_cpu_offload:
394
- model.to(device)
395
- else:
396
- _, hook = cpu_offload_with_hook(model, device)
397
- self._all_hooks.append(hook)
398
-
399
- def maybe_free_model_hooks(self):
400
- r"""
401
- Function that offloads all components, removes all model hooks that were added when using
402
- `enable_model_cpu_offload` and then applies them again. In case the model has not been offloaded this function
403
- is a no-op. Make sure to add this function to the end of the `__call__` function of your pipeline so that it
404
- functions correctly when applying enable_model_cpu_offload.
405
- """
406
- if not hasattr(self, "_all_hooks") or len(self._all_hooks) == 0:
407
- # `enable_model_cpu_offload` has not be called, so silently do nothing
408
- return
409
-
410
- for hook in self._all_hooks:
411
- # offload model and remove hook from model
412
- hook.offload()
413
- hook.remove()
414
-
415
- # make sure the model is in the same state as before calling it
416
- self.enable_model_cpu_offload()
417
-
418
- @synchronize_timer('Encode cond')
419
- def encode_cond(self, image, additional_cond_inputs, do_classifier_free_guidance, dual_guidance):
420
- bsz = image.shape[0]
421
- cond = self.conditioner(image=image, **additional_cond_inputs) # cond['main'].shape
422
-
423
- if do_classifier_free_guidance:
424
- cond_token_num = cond["main"].shape[1]
425
- additional_cond_inputs["num_tokens"] = cond_token_num
426
- un_cond = self.conditioner.unconditional_embedding(bsz, **additional_cond_inputs)
427
-
428
- if dual_guidance:
429
- un_cond_drop_main = copy.deepcopy(un_cond)
430
- un_cond_drop_main['additional'] = cond['additional']
431
-
432
- def cat_recursive(a, b, c):
433
- if isinstance(a, torch.Tensor):
434
- return torch.cat([a, b, c], dim=0).to(self.dtype)
435
- out = {}
436
- for k in a.keys():
437
- out[k] = cat_recursive(a[k], b[k], c[k])
438
- return out
439
-
440
- cond = cat_recursive(cond, un_cond_drop_main, un_cond)
441
- else:
442
- def cat_recursive(a, b):
443
- if isinstance(a, torch.Tensor):
444
- return torch.cat([a, b], dim=0).to(self.dtype)
445
- out = {}
446
- for k in a.keys():
447
- out[k] = cat_recursive(a[k], b[k])
448
- return out
449
-
450
- cond = cat_recursive(cond, un_cond)
451
- return cond
452
-
453
- def prepare_extra_step_kwargs(self, generator, eta):
454
- # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
455
- # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
456
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
457
- # and should be between [0, 1]
458
-
459
- accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
460
- extra_step_kwargs = {}
461
- if accepts_eta:
462
- extra_step_kwargs["eta"] = eta
463
-
464
- # check if the scheduler accepts generator
465
- accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
466
- if accepts_generator:
467
- extra_step_kwargs["generator"] = generator
468
- return extra_step_kwargs
469
-
470
- def prepare_latents(self, batch_size, dtype, device, generator, latents=None, shape=None):
471
- if shape is None:
472
- shape = (batch_size, *self.vae.latent_shape)
473
- else:
474
- shape = (batch_size, *shape)
475
- if isinstance(generator, list) and len(generator) != batch_size:
476
- raise ValueError(
477
- f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
478
- f" size of {batch_size}. Make sure the batch size matches the length of the generators."
479
- )
480
-
481
- if latents is None:
482
- latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
483
- else:
484
- latents = latents.to(device)
485
-
486
- # scale the initial noise by the standard deviation required by the scheduler
487
- latents = latents * getattr(self.scheduler, 'init_noise_sigma', 1.0)
488
- return latents
489
-
490
- def prepare_image(self, image, mask=None) -> dict:
491
- if isinstance(image, torch.Tensor) and isinstance(mask, torch.Tensor):
492
- outputs = {
493
- 'image': image,
494
- 'mask': mask
495
- }
496
- return outputs
497
-
498
- if isinstance(image, str) and not os.path.exists(image):
499
- raise FileNotFoundError(f"Couldn't find image at path {image}")
500
-
501
- if not isinstance(image, list):
502
- image = [image]
503
-
504
- outputs = []
505
- for img in image:
506
- output = self.image_processor(img) # output['image'].shape
507
- outputs.append(output)
508
-
509
- cond_input = {k: [] for k in outputs[0].keys()}
510
- for output in outputs:
511
- for key, value in output.items():
512
- cond_input[key].append(value)
513
- for key, value in cond_input.items():
514
- if isinstance(value[0], torch.Tensor):
515
- cond_input[key] = torch.cat(value, dim=0)
516
-
517
- return cond_input
518
-
519
- def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
520
- """
521
- See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
522
-
523
- Args:
524
- timesteps (`torch.Tensor`):
525
- generate embedding vectors at these timesteps
526
- embedding_dim (`int`, *optional*, defaults to 512):
527
- dimension of the embeddings to generate
528
- dtype:
529
- data type of the generated embeddings
530
-
531
- Returns:
532
- `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
533
- """
534
- assert len(w.shape) == 1
535
- w = w * 1000.0
536
-
537
- half_dim = embedding_dim // 2
538
- emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
539
- emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
540
- emb = w.to(dtype)[:, None] * emb[None, :]
541
- emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
542
- if embedding_dim % 2 == 1: # zero pad
543
- emb = torch.nn.functional.pad(emb, (0, 1))
544
- assert emb.shape == (w.shape[0], embedding_dim)
545
- return emb
546
-
547
- def set_surface_extractor(self, mc_algo):
548
- if mc_algo is None:
549
- return
550
- logger.info('The parameters `mc_algo` is deprecated, and will be removed in future versions.\n'
551
- 'Please use: \n'
552
- 'from hy3dshape.models.autoencoders import SurfaceExtractors\n'
553
- 'pipeline.vae.surface_extractor = SurfaceExtractors[mc_algo]() instead\n')
554
- if mc_algo not in SurfaceExtractors.keys():
555
- raise ValueError(f"Unknown mc_algo {mc_algo}")
556
- self.vae.surface_extractor = SurfaceExtractors[mc_algo]()
557
-
558
- @torch.no_grad()
559
- def __call__(
560
- self,
561
- image: Union[str, List[str], Image.Image] = None,
562
- num_inference_steps: int = 50,
563
- timesteps: List[int] = None,
564
- sigmas: List[float] = None,
565
- eta: float = 0.0,
566
- guidance_scale: float = 7.5,
567
- dual_guidance_scale: float = 10.5,
568
- dual_guidance: bool = True,
569
- generator=None,
570
- box_v=1.01,
571
- octree_resolution=384,
572
- mc_level=-1 / 512,
573
- num_chunks=8000,
574
- mc_algo=None,
575
- output_type: Optional[str] = "trimesh",
576
- enable_pbar=True,
577
- **kwargs,
578
- ) -> List[List[trimesh.Trimesh]]:
579
- callback = kwargs.pop("callback", None)
580
- callback_steps = kwargs.pop("callback_steps", None)
581
-
582
- self.set_surface_extractor(mc_algo)
583
-
584
- device = self.device
585
- dtype = self.dtype
586
- do_classifier_free_guidance = guidance_scale >= 0 and \
587
- getattr(self.model, 'guidance_cond_proj_dim', None) is None
588
- dual_guidance = dual_guidance_scale >= 0 and dual_guidance
589
-
590
- if isinstance(image, torch.Tensor):
591
- pass
592
- else:
593
- cond_inputs = self.prepare_image(image)
594
- image = cond_inputs.pop('image')
595
-
596
- cond = self.encode_cond(
597
- image=image,
598
- additional_cond_inputs=cond_inputs,
599
- do_classifier_free_guidance=do_classifier_free_guidance,
600
- dual_guidance=False,
601
- )
602
- batch_size = image.shape[0]
603
-
604
- t_dtype = torch.long
605
- timesteps, num_inference_steps = retrieve_timesteps(
606
- self.scheduler, num_inference_steps, device, timesteps, sigmas)
607
-
608
- latents = self.prepare_latents(batch_size, dtype, device, generator)
609
- extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
610
-
611
- guidance_cond = None
612
- if getattr(self.model, 'guidance_cond_proj_dim', None) is not None:
613
- logger.info('Using lcm guidance scale')
614
- guidance_scale_tensor = torch.tensor(guidance_scale - 1).repeat(batch_size)
615
- guidance_cond = self.get_guidance_scale_embedding(
616
- guidance_scale_tensor, embedding_dim=self.model.guidance_cond_proj_dim
617
- ).to(device=device, dtype=latents.dtype)
618
- with synchronize_timer('Diffusion Sampling'):
619
- for i, t in enumerate(tqdm(timesteps, disable=not enable_pbar, desc="Diffusion Sampling:", leave=False)):
620
- # expand the latents if we are doing classifier free guidance
621
- if do_classifier_free_guidance:
622
- latent_model_input = torch.cat([latents] * (3 if dual_guidance else 2))
623
- else:
624
- latent_model_input = latents
625
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
626
-
627
- # predict the noise residual
628
- timestep_tensor = torch.tensor([t], dtype=t_dtype, device=device)
629
- timestep_tensor = timestep_tensor.expand(latent_model_input.shape[0])
630
- noise_pred = self.model(latent_model_input, timestep_tensor, cond, guidance_cond=guidance_cond)
631
-
632
- # no drop, drop clip, all drop
633
- if do_classifier_free_guidance:
634
- if dual_guidance:
635
- noise_pred_clip, noise_pred_dino, noise_pred_uncond = noise_pred.chunk(3)
636
- noise_pred = (
637
- noise_pred_uncond
638
- + guidance_scale * (noise_pred_clip - noise_pred_dino)
639
- + dual_guidance_scale * (noise_pred_dino - noise_pred_uncond)
640
- )
641
- else:
642
- noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
643
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
644
-
645
- # compute the previous noisy sample x_t -> x_t-1
646
- outputs = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)
647
- latents = outputs.prev_sample
648
-
649
- if callback is not None and i % callback_steps == 0:
650
- step_idx = i // getattr(self.scheduler, "order", 1)
651
- callback(step_idx, t, outputs)
652
-
653
- return self._export(
654
- latents,
655
- output_type,
656
- box_v, mc_level, num_chunks, octree_resolution, mc_algo,
657
- )
658
-
659
- def _export(
660
- self,
661
- latents,
662
- output_type='trimesh',
663
- box_v=1.01,
664
- mc_level=0.0,
665
- num_chunks=20000,
666
- octree_resolution=256,
667
- mc_algo='mc',
668
- enable_pbar=True
669
- ):
670
- if not output_type == "latent":
671
- latents = 1. / self.vae.scale_factor * latents
672
- latents = self.vae(latents)
673
- outputs, _ = self.vae.latents2mesh(
674
- latents,
675
- bounds=box_v,
676
- mc_level=mc_level,
677
- num_chunks=num_chunks,
678
- octree_resolution=octree_resolution,
679
- mc_algo=mc_algo,
680
- enable_pbar=enable_pbar,
681
- )
682
- else:
683
- outputs = latents
684
-
685
- if output_type == 'trimesh':
686
- outputs = export_to_trimesh(outputs)
687
-
688
- return outputs
689
-
690
-
691
- class UltraShapePipeline(DiTPipeline):
692
-
693
- @torch.inference_mode()
694
- def __call__(
695
- self,
696
- image: Union[str, List[str], Image.Image, dict, List[dict], torch.Tensor] = None,
697
- voxel_cond: torch.Tensor = None,
698
- num_inference_steps: int = 50,
699
- timesteps: List[int] = None,
700
- sigmas: List[float] = None,
701
- eta: float = 0.0,
702
- guidance_scale: float = 5.0,
703
- generator=None,
704
- box_v=1.01,
705
- octree_resolution=384,
706
- mc_level=0.0,
707
- mc_algo=None,
708
- num_chunks=8000,
709
- output_type: Optional[str] = "trimesh",
710
- enable_pbar=True,
711
- mask = None,
712
- **kwargs,
713
- ) -> List[List[trimesh.Trimesh]]:
714
- callback = kwargs.pop("callback", None)
715
- callback_steps = kwargs.pop("callback_steps", None)
716
-
717
- self.set_surface_extractor(mc_algo)
718
-
719
- device = self.device
720
- dtype = self.dtype
721
- do_classifier_free_guidance = guidance_scale >= 0 and not (
722
- hasattr(self.model, 'guidance_embed') and
723
- self.model.guidance_embed is True
724
- )
725
-
726
- # print('image', type(image), 'mask', type(mask))
727
- cond_inputs = self.prepare_image(image, mask)
728
- image = cond_inputs.pop('image')
729
- cond = self.encode_cond(
730
- image=image,
731
- additional_cond_inputs=cond_inputs,
732
- do_classifier_free_guidance=do_classifier_free_guidance,
733
- dual_guidance=False,
734
- )
735
-
736
- batch_size = image.shape[0]
737
-
738
- # 5. Prepare timesteps
739
- # NOTE: this is slightly different from common usage, we start from 0.
740
- sigmas = np.linspace(0, 1, num_inference_steps) if sigmas is None else sigmas
741
- timesteps, num_inference_steps = retrieve_timesteps(
742
- self.scheduler,
743
- num_inference_steps,
744
- device,
745
- sigmas=sigmas,
746
- )
747
- latents_shape = None
748
- if voxel_cond is not None:
749
- # voxel_cond: [B, N, 3] -> [N, 3] if batched? No, it's [B, N, 3] usually
750
- # The encoder expects [B, N, 3]
751
- num_tokens = voxel_cond.shape[1]
752
- latents_shape = (num_tokens, self.vae.latent_shape[-1])
753
-
754
- latents = self.prepare_latents(batch_size, dtype, device, generator, shape=latents_shape)
755
-
756
- guidance = None
757
- if hasattr(self.model, 'guidance_embed') and \
758
- self.model.guidance_embed is True:
759
- guidance = torch.tensor([guidance_scale] * batch_size, device=device, dtype=dtype)
760
- # logger.info(f'Using guidance embed with scale {guidance_scale}')
761
- if do_classifier_free_guidance and voxel_cond is not None:
762
- voxel_cond = torch.cat([voxel_cond] * 2)
763
- with synchronize_timer('Diffusion Sampling'):
764
- for i, t in enumerate(tqdm(timesteps, disable=not enable_pbar, desc="Diffusion Sampling:")):
765
- # expand the latents if we are doing classifier free guidance
766
- if do_classifier_free_guidance:
767
- latent_model_input = torch.cat([latents] * 2)
768
- else:
769
- latent_model_input = latents
770
-
771
- # NOTE: we assume model get timesteps ranged from 0 to 1
772
- timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype).to(latent_model_input.device)
773
- timestep = timestep / self.scheduler.config.num_train_timesteps
774
- if voxel_cond is None:
775
- noise_pred = self.model(latent_model_input, timestep, cond, guidance=guidance)
776
- else:
777
- noise_pred = self.model(latent_model_input, timestep, cond,
778
- guidance=guidance, voxel_cond=voxel_cond)
779
-
780
- if do_classifier_free_guidance:
781
- noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
782
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
783
-
784
- # compute the previous noisy sample x_t -> x_t-1
785
- outputs = self.scheduler.step(noise_pred, t, latents)
786
- latents = outputs.prev_sample
787
-
788
- if callback is not None and i % callback_steps == 0:
789
- step_idx = i // getattr(self.scheduler, "order", 1)
790
- callback(step_idx, t, outputs)
791
-
792
- return self._export(
793
- latents,
794
- output_type,
795
- box_v, mc_level, num_chunks, octree_resolution, mc_algo,
796
- enable_pbar=enable_pbar,
797
- ), latents
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ultrashape/postprocessors.py DELETED
@@ -1,209 +0,0 @@
1
- # ==============================================================================
2
- # Original work Copyright (c) 2025 Tencent.
3
- # Modified work Copyright (c) 2025 UltraShape Team.
4
- #
5
- # Modified by UltraShape on 2025.12.25
6
- # ==============================================================================
7
-
8
- # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
9
- # except for the third-party components listed below.
10
- # Hunyuan 3D does not impose any additional limitations beyond what is outlined
11
- # in the repsective licenses of these third-party components.
12
- # Users must comply with all terms and conditions of original licenses of these third-party
13
- # components and must ensure that the usage of the third party components adheres to
14
- # all relevant laws and regulations.
15
-
16
- # For avoidance of doubts, Hunyuan 3D means the large language models and
17
- # their software and algorithms, including trained model weights, parameters (including
18
- # optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
19
- # fine-tuning enabling code and other elements of the foregoing made publicly available
20
- # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
21
-
22
- import os
23
- import tempfile
24
- from typing import Union
25
-
26
- import numpy as np
27
- import pymeshlab
28
- import torch
29
- import trimesh
30
-
31
- from .models.autoencoders import Latent2MeshOutput
32
- from .utils import synchronize_timer
33
-
34
-
35
- def load_mesh(path):
36
- if path.endswith(".glb"):
37
- mesh = trimesh.load(path)
38
- else:
39
- mesh = pymeshlab.MeshSet()
40
- mesh.load_new_mesh(path)
41
- return mesh
42
-
43
-
44
- def reduce_face(mesh: pymeshlab.MeshSet, max_facenum: int = 200000):
45
- if max_facenum > mesh.current_mesh().face_number():
46
- return mesh
47
-
48
- mesh.apply_filter(
49
- "meshing_decimation_quadric_edge_collapse",
50
- targetfacenum=max_facenum,
51
- qualitythr=1.0,
52
- preserveboundary=True,
53
- boundaryweight=3,
54
- preservenormal=True,
55
- preservetopology=True,
56
- autoclean=True
57
- )
58
- return mesh
59
-
60
-
61
- def remove_floater(mesh: pymeshlab.MeshSet):
62
- mesh.apply_filter("compute_selection_by_small_disconnected_components_per_face",
63
- nbfaceratio=0.005)
64
- mesh.apply_filter("compute_selection_transfer_face_to_vertex", inclusive=False)
65
- mesh.apply_filter("meshing_remove_selected_vertices_and_faces")
66
- return mesh
67
-
68
-
69
- def pymeshlab2trimesh(mesh: pymeshlab.MeshSet):
70
- with tempfile.NamedTemporaryFile(suffix='.ply', delete=False) as temp_file:
71
- mesh.save_current_mesh(temp_file.name)
72
- mesh = trimesh.load(temp_file.name)
73
- # 检查加载的对象类型
74
- if isinstance(mesh, trimesh.Scene):
75
- combined_mesh = trimesh.Trimesh()
76
- # 如果是Scene,遍历所有的geometry并合并
77
- for geom in mesh.geometry.values():
78
- combined_mesh = trimesh.util.concatenate([combined_mesh, geom])
79
- mesh = combined_mesh
80
- return mesh
81
-
82
-
83
- def trimesh2pymeshlab(mesh: trimesh.Trimesh):
84
- with tempfile.NamedTemporaryFile(suffix='.ply', delete=False) as temp_file:
85
- if isinstance(mesh, trimesh.scene.Scene):
86
- for idx, obj in enumerate(mesh.geometry.values()):
87
- if idx == 0:
88
- temp_mesh = obj
89
- else:
90
- temp_mesh = temp_mesh + obj
91
- mesh = temp_mesh
92
- mesh.export(temp_file.name)
93
- mesh = pymeshlab.MeshSet()
94
- mesh.load_new_mesh(temp_file.name)
95
- return mesh
96
-
97
-
98
- def export_mesh(input, output):
99
- if isinstance(input, pymeshlab.MeshSet):
100
- mesh = output
101
- elif isinstance(input, Latent2MeshOutput):
102
- output = Latent2MeshOutput()
103
- output.mesh_v = output.current_mesh().vertex_matrix()
104
- output.mesh_f = output.current_mesh().face_matrix()
105
- mesh = output
106
- else:
107
- mesh = pymeshlab2trimesh(output)
108
- return mesh
109
-
110
-
111
- def import_mesh(mesh: Union[pymeshlab.MeshSet, trimesh.Trimesh, Latent2MeshOutput, str]) -> pymeshlab.MeshSet:
112
- if isinstance(mesh, str):
113
- mesh = load_mesh(mesh)
114
- elif isinstance(mesh, Latent2MeshOutput):
115
- mesh = pymeshlab.MeshSet()
116
- mesh_pymeshlab = pymeshlab.Mesh(vertex_matrix=mesh.mesh_v, face_matrix=mesh.mesh_f)
117
- mesh.add_mesh(mesh_pymeshlab, "converted_mesh")
118
-
119
- if isinstance(mesh, (trimesh.Trimesh, trimesh.scene.Scene)):
120
- mesh = trimesh2pymeshlab(mesh)
121
-
122
- return mesh
123
-
124
-
125
- class FaceReducer:
126
- @synchronize_timer('FaceReducer')
127
- def __call__(
128
- self,
129
- mesh: Union[pymeshlab.MeshSet, trimesh.Trimesh, Latent2MeshOutput, str],
130
- max_facenum: int = 40000
131
- ) -> Union[pymeshlab.MeshSet, trimesh.Trimesh]:
132
- ms = import_mesh(mesh)
133
- ms = reduce_face(ms, max_facenum=max_facenum)
134
- mesh = export_mesh(mesh, ms)
135
- return mesh
136
-
137
-
138
- class FloaterRemover:
139
- @synchronize_timer('FloaterRemover')
140
- def __call__(
141
- self,
142
- mesh: Union[pymeshlab.MeshSet, trimesh.Trimesh, Latent2MeshOutput, str],
143
- ) -> Union[pymeshlab.MeshSet, trimesh.Trimesh, Latent2MeshOutput]:
144
- ms = import_mesh(mesh)
145
- ms = remove_floater(ms)
146
- mesh = export_mesh(mesh, ms)
147
- return mesh
148
-
149
-
150
- class DegenerateFaceRemover:
151
- @synchronize_timer('DegenerateFaceRemover')
152
- def __call__(
153
- self,
154
- mesh: Union[pymeshlab.MeshSet, trimesh.Trimesh, Latent2MeshOutput, str],
155
- ) -> Union[pymeshlab.MeshSet, trimesh.Trimesh, Latent2MeshOutput]:
156
- ms = import_mesh(mesh)
157
-
158
- with tempfile.NamedTemporaryFile(suffix='.ply', delete=False) as temp_file:
159
- ms.save_current_mesh(temp_file.name)
160
- ms = pymeshlab.MeshSet()
161
- ms.load_new_mesh(temp_file.name)
162
-
163
- mesh = export_mesh(mesh, ms)
164
- return mesh
165
-
166
-
167
- def mesh_normalize(mesh):
168
- """
169
- Normalize mesh vertices to sphere
170
- """
171
- scale_factor = 1.2
172
- vtx_pos = np.asarray(mesh.vertices)
173
- max_bb = (vtx_pos - 0).max(0)[0]
174
- min_bb = (vtx_pos - 0).min(0)[0]
175
-
176
- center = (max_bb + min_bb) / 2
177
-
178
- scale = torch.norm(torch.tensor(vtx_pos - center, dtype=torch.float32), dim=1).max() * 2.0
179
-
180
- vtx_pos = (vtx_pos - center) * (scale_factor / float(scale))
181
- mesh.vertices = vtx_pos
182
-
183
- return mesh
184
-
185
-
186
- class MeshSimplifier:
187
- def __init__(self, executable: str = None):
188
- if executable is None:
189
- CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
190
- executable = os.path.join(CURRENT_DIR, "mesh_simplifier.bin")
191
- self.executable = executable
192
-
193
- @synchronize_timer('MeshSimplifier')
194
- def __call__(
195
- self,
196
- mesh: Union[trimesh.Trimesh],
197
- ) -> Union[trimesh.Trimesh]:
198
- with tempfile.NamedTemporaryFile(suffix='.obj', delete=False) as temp_input:
199
- with tempfile.NamedTemporaryFile(suffix='.obj', delete=False) as temp_output:
200
- mesh.export(temp_input.name)
201
- os.system(f'{self.executable} {temp_input.name} {temp_output.name}')
202
- ms = trimesh.load(temp_output.name, process=False)
203
- if isinstance(ms, trimesh.Scene):
204
- combined_mesh = trimesh.Trimesh()
205
- for geom in ms.geometry.values():
206
- combined_mesh = trimesh.util.concatenate([combined_mesh, geom])
207
- ms = combined_mesh
208
- ms = mesh_normalize(ms)
209
- return ms
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ultrashape/preprocessors.py DELETED
@@ -1,167 +0,0 @@
1
- # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
2
- # except for the third-party components listed below.
3
- # Hunyuan 3D does not impose any additional limitations beyond what is outlined
4
- # in the repsective licenses of these third-party components.
5
- # Users must comply with all terms and conditions of original licenses of these third-party
6
- # components and must ensure that the usage of the third party components adheres to
7
- # all relevant laws and regulations.
8
-
9
- # For avoidance of doubts, Hunyuan 3D means the large language models and
10
- # their software and algorithms, including trained model weights, parameters (including
11
- # optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
12
- # fine-tuning enabling code and other elements of the foregoing made publicly available
13
- # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
14
-
15
- import cv2
16
- import numpy as np
17
- import torch
18
- from PIL import Image
19
- from einops import repeat, rearrange
20
-
21
-
22
- def array_to_tensor(np_array):
23
- image_pt = torch.tensor(np_array).float()
24
- image_pt = image_pt / 255 * 2 - 1
25
- image_pt = rearrange(image_pt, "h w c -> c h w")
26
- image_pts = repeat(image_pt, "c h w -> b c h w", b=1)
27
- return image_pts
28
-
29
-
30
- class ImageProcessorV2:
31
- def __init__(self, size=512, border_ratio=None):
32
- self.size = size
33
- self.border_ratio = border_ratio
34
-
35
- @staticmethod
36
- def recenter(image, border_ratio: float = 0.2):
37
- """ recenter an image to leave some empty space at the image border.
38
-
39
- Args:
40
- image (ndarray): input image, float/uint8 [H, W, 3/4]
41
- mask (ndarray): alpha mask, bool [H, W]
42
- border_ratio (float, optional): border ratio, image will be resized to (1 - border_ratio). Defaults to 0.2.
43
-
44
- Returns:
45
- ndarray: output image, float/uint8 [H, W, 3/4]
46
- """
47
-
48
- if image.shape[-1] == 4:
49
- mask = image[..., 3]
50
- else:
51
- mask = np.ones_like(image[..., 0:1]) * 255
52
- image = np.concatenate([image, mask], axis=-1)
53
- mask = mask[..., 0]
54
-
55
- H, W, C = image.shape
56
-
57
- size = max(H, W)
58
- result = np.zeros((size, size, C), dtype=np.uint8)
59
-
60
- coords = np.nonzero(mask)
61
- x_min, x_max = coords[0].min(), coords[0].max()
62
- y_min, y_max = coords[1].min(), coords[1].max()
63
- h = x_max - x_min
64
- w = y_max - y_min
65
- if h == 0 or w == 0:
66
- raise ValueError('input image is empty')
67
- desired_size = int(size * (1 - border_ratio))
68
- scale = desired_size / max(h, w)
69
- h2 = int(h * scale)
70
- w2 = int(w * scale)
71
- x2_min = (size - h2) // 2
72
- x2_max = x2_min + h2
73
-
74
- y2_min = (size - w2) // 2
75
- y2_max = y2_min + w2
76
-
77
- result[x2_min:x2_max, y2_min:y2_max] = cv2.resize(image[x_min:x_max, y_min:y_max], (w2, h2),
78
- interpolation=cv2.INTER_AREA)
79
-
80
- bg = np.ones((result.shape[0], result.shape[1], 3), dtype=np.uint8) * 255
81
-
82
- mask = result[..., 3:].astype(np.float32) / 255
83
- result = result[..., :3] * mask + bg * (1 - mask)
84
-
85
- mask = mask * 255
86
- result = result.clip(0, 255).astype(np.uint8)
87
- mask = mask.clip(0, 255).astype(np.uint8)
88
- return result, mask
89
-
90
- def load_image(self, image, border_ratio=0.15, to_tensor=True):
91
- if isinstance(image, str):
92
- image = cv2.imread(image, cv2.IMREAD_UNCHANGED)
93
- image, mask = self.recenter(image, border_ratio=border_ratio)
94
- image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
95
- elif isinstance(image, Image.Image):
96
- image = image.convert("RGBA")
97
- image = np.asarray(image)
98
- image, mask = self.recenter(image, border_ratio=border_ratio)
99
-
100
- image = cv2.resize(image, (self.size, self.size), interpolation=cv2.INTER_CUBIC)
101
- mask = cv2.resize(mask, (self.size, self.size), interpolation=cv2.INTER_NEAREST)
102
- mask = mask[..., np.newaxis]
103
-
104
- if to_tensor:
105
- image = array_to_tensor(image)
106
- mask = array_to_tensor(mask)
107
- return image, mask
108
-
109
- def __call__(self, image, border_ratio=0.15, to_tensor=True, **kwargs):
110
- if self.border_ratio is not None:
111
- border_ratio = self.border_ratio
112
- image, mask = self.load_image(image, border_ratio=border_ratio, to_tensor=to_tensor)
113
- outputs = {
114
- 'image': image,
115
- 'mask': mask
116
- }
117
- return outputs
118
-
119
-
120
- class MVImageProcessorV2(ImageProcessorV2):
121
- """
122
- view order: front, front clockwise 90, back, front clockwise 270
123
- """
124
- return_view_idx = True
125
-
126
- def __init__(self, size=512, border_ratio=None):
127
- super().__init__(size, border_ratio)
128
- self.view2idx = {
129
- 'front': 0,
130
- 'left': 1,
131
- 'back': 2,
132
- 'right': 3
133
- }
134
-
135
- def __call__(self, image_dict, border_ratio=0.15, to_tensor=True, **kwargs):
136
- if self.border_ratio is not None:
137
- border_ratio = self.border_ratio
138
-
139
- images = []
140
- masks = []
141
- view_idxs = []
142
- for idx, (view_tag, image) in enumerate(image_dict.items()):
143
- view_idxs.append(self.view2idx[view_tag])
144
- image, mask = self.load_image(image, border_ratio=border_ratio, to_tensor=to_tensor)
145
- images.append(image)
146
- masks.append(mask)
147
-
148
- zipped_lists = zip(view_idxs, images, masks)
149
- sorted_zipped_lists = sorted(zipped_lists)
150
- view_idxs, images, masks = zip(*sorted_zipped_lists)
151
-
152
- image = torch.cat(images, 0).unsqueeze(0)
153
- mask = torch.cat(masks, 0).unsqueeze(0)
154
- outputs = {
155
- 'image': image,
156
- 'mask': mask,
157
- 'view_idxs': view_idxs
158
- }
159
- return outputs
160
-
161
-
162
- IMAGE_PROCESSORS = {
163
- "v2": ImageProcessorV2,
164
- 'mv_v2': MVImageProcessorV2,
165
- }
166
-
167
- DEFAULT_IMAGEPROCESSOR = 'v2'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ultrashape/rembg.py DELETED
@@ -1,32 +0,0 @@
1
- # ==============================================================================
2
- # Original work Copyright (c) 2025 Tencent.
3
- # Modified work Copyright (c) 2025 UltraShape Team.
4
- #
5
- # Modified by UltraShape on 2025.12.25
6
- # ==============================================================================
7
-
8
- # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
9
- # except for the third-party components listed below.
10
- # Hunyuan 3D does not impose any additional limitations beyond what is outlined
11
- # in the repsective licenses of these third-party components.
12
- # Users must comply with all terms and conditions of original licenses of these third-party
13
- # components and must ensure that the usage of the third party components adheres to
14
- # all relevant laws and regulations.
15
-
16
- # For avoidance of doubts, Hunyuan 3D means the large language models and
17
- # their software and algorithms, including trained model weights, parameters (including
18
- # optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
19
- # fine-tuning enabling code and other elements of the foregoing made publicly available
20
- # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
21
-
22
- from PIL import Image
23
- from rembg import remove, new_session
24
-
25
-
26
- class BackgroundRemover():
27
- def __init__(self):
28
- self.session = new_session()
29
-
30
- def __call__(self, image: Image.Image):
31
- output = remove(image, session=self.session, bgcolor=[255, 255, 255, 0])
32
- return output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ultrashape/schedulers.py DELETED
@@ -1,480 +0,0 @@
1
- # Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
16
- # except for the third-party components listed below.
17
- # Hunyuan 3D does not impose any additional limitations beyond what is outlined
18
- # in the repsective licenses of these third-party components.
19
- # Users must comply with all terms and conditions of original licenses of these third-party
20
- # components and must ensure that the usage of the third party components adheres to
21
- # all relevant laws and regulations.
22
-
23
- # For avoidance of doubts, Hunyuan 3D means the large language models and
24
- # their software and algorithms, including trained model weights, parameters (including
25
- # optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
26
- # fine-tuning enabling code and other elements of the foregoing made publicly available
27
- # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
28
-
29
- import math
30
- from dataclasses import dataclass
31
- from typing import List, Optional, Tuple, Union
32
-
33
- import numpy as np
34
- import torch
35
- from diffusers.configuration_utils import ConfigMixin, register_to_config
36
- from diffusers.schedulers.scheduling_utils import SchedulerMixin
37
- from diffusers.utils import BaseOutput, logging
38
-
39
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
40
-
41
-
42
- @dataclass
43
- class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
44
- """
45
- Output class for the scheduler's `step` function output.
46
-
47
- Args:
48
- prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
49
- Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
50
- denoising loop.
51
- """
52
-
53
- prev_sample: torch.FloatTensor
54
-
55
-
56
- class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
57
- """
58
- NOTE: this is very similar to diffusers.FlowMatchEulerDiscreteScheduler. Except our timesteps are reversed
59
-
60
- Euler scheduler.
61
-
62
- This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
63
- methods the library implements for all schedulers such as loading and saving.
64
-
65
- Args:
66
- num_train_timesteps (`int`, defaults to 1000):
67
- The number of diffusion steps to train the model.
68
- timestep_spacing (`str`, defaults to `"linspace"`):
69
- The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
70
- Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
71
- shift (`float`, defaults to 1.0):
72
- The shift value for the timestep schedule.
73
- """
74
-
75
- _compatibles = []
76
- order = 1
77
-
78
- @register_to_config
79
- def __init__(
80
- self,
81
- num_train_timesteps: int = 1000,
82
- shift: float = 1.0,
83
- use_dynamic_shifting=False,
84
- ):
85
- timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32).copy()
86
- timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
87
-
88
- sigmas = timesteps / num_train_timesteps
89
- if not use_dynamic_shifting:
90
- # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
91
- sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
92
-
93
- self.timesteps = sigmas * num_train_timesteps
94
-
95
- self._step_index = None
96
- self._begin_index = None
97
-
98
- self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication
99
- self.sigma_min = self.sigmas[-1].item()
100
- self.sigma_max = self.sigmas[0].item()
101
-
102
- @property
103
- def step_index(self):
104
- """
105
- The index counter for current timestep. It will increase 1 after each scheduler step.
106
- """
107
- return self._step_index
108
-
109
- @property
110
- def begin_index(self):
111
- """
112
- The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
113
- """
114
- return self._begin_index
115
-
116
- # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
117
- def set_begin_index(self, begin_index: int = 0):
118
- """
119
- Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
120
-
121
- Args:
122
- begin_index (`int`):
123
- The begin index for the scheduler.
124
- """
125
- self._begin_index = begin_index
126
-
127
- def scale_noise(
128
- self,
129
- sample: torch.FloatTensor,
130
- timestep: Union[float, torch.FloatTensor],
131
- noise: Optional[torch.FloatTensor] = None,
132
- ) -> torch.FloatTensor:
133
- """
134
- Forward process in flow-matching
135
-
136
- Args:
137
- sample (`torch.FloatTensor`):
138
- The input sample.
139
- timestep (`int`, *optional*):
140
- The current timestep in the diffusion chain.
141
-
142
- Returns:
143
- `torch.FloatTensor`:
144
- A scaled input sample.
145
- """
146
- # Make sure sigmas and timesteps have the same device and dtype as original_samples
147
- sigmas = self.sigmas.to(device=sample.device, dtype=sample.dtype)
148
-
149
- if sample.device.type == "mps" and torch.is_floating_point(timestep):
150
- # mps does not support float64
151
- schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32)
152
- timestep = timestep.to(sample.device, dtype=torch.float32)
153
- else:
154
- schedule_timesteps = self.timesteps.to(sample.device)
155
- timestep = timestep.to(sample.device)
156
-
157
- # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
158
- if self.begin_index is None:
159
- step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timestep]
160
- elif self.step_index is not None:
161
- # add_noise is called after first denoising step (for inpainting)
162
- step_indices = [self.step_index] * timestep.shape[0]
163
- else:
164
- # add noise is called before first denoising step to create initial latent(img2img)
165
- step_indices = [self.begin_index] * timestep.shape[0]
166
-
167
- sigma = sigmas[step_indices].flatten()
168
- while len(sigma.shape) < len(sample.shape):
169
- sigma = sigma.unsqueeze(-1)
170
-
171
- sample = sigma * noise + (1.0 - sigma) * sample
172
-
173
- return sample
174
-
175
- def _sigma_to_t(self, sigma):
176
- return sigma * self.config.num_train_timesteps
177
-
178
- def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
179
- return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
180
-
181
- def set_timesteps(
182
- self,
183
- num_inference_steps: int = None,
184
- device: Union[str, torch.device] = None,
185
- sigmas: Optional[List[float]] = None,
186
- mu: Optional[float] = None,
187
- ):
188
- """
189
- Sets the discrete timesteps used for the diffusion chain (to be run before inference).
190
-
191
- Args:
192
- num_inference_steps (`int`):
193
- The number of diffusion steps used when generating samples with a pre-trained model.
194
- device (`str` or `torch.device`, *optional*):
195
- The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
196
- """
197
-
198
- if self.config.use_dynamic_shifting and mu is None:
199
- raise ValueError(" you have a pass a value for `mu` when `use_dynamic_shifting` is set to be `True`")
200
-
201
- if sigmas is None:
202
- self.num_inference_steps = num_inference_steps
203
- timesteps = np.linspace(
204
- self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
205
- )
206
-
207
- sigmas = timesteps / self.config.num_train_timesteps
208
-
209
- if self.config.use_dynamic_shifting:
210
- sigmas = self.time_shift(mu, 1.0, sigmas)
211
- else:
212
- sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
213
-
214
- sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
215
- timesteps = sigmas * self.config.num_train_timesteps
216
-
217
- self.timesteps = timesteps.to(device=device)
218
- self.sigmas = torch.cat([sigmas, torch.ones(1, device=sigmas.device)])
219
-
220
- self._step_index = None
221
- self._begin_index = None
222
-
223
- def index_for_timestep(self, timestep, schedule_timesteps=None):
224
- if schedule_timesteps is None:
225
- schedule_timesteps = self.timesteps
226
-
227
- indices = (schedule_timesteps == timestep).nonzero()
228
-
229
- # The sigma index that is taken for the **very** first `step`
230
- # is always the second index (or the last index if there is only 1)
231
- # This way we can ensure we don't accidentally skip a sigma in
232
- # case we start in the middle of the denoising schedule (e.g. for image-to-image)
233
- pos = 1 if len(indices) > 1 else 0
234
-
235
- return indices[pos].item()
236
-
237
- def _init_step_index(self, timestep):
238
- if self.begin_index is None:
239
- if isinstance(timestep, torch.Tensor):
240
- timestep = timestep.to(self.timesteps.device)
241
- self._step_index = self.index_for_timestep(timestep)
242
- else:
243
- self._step_index = self._begin_index
244
-
245
- def step(
246
- self,
247
- model_output: torch.FloatTensor,
248
- timestep: Union[float, torch.FloatTensor],
249
- sample: torch.FloatTensor,
250
- s_churn: float = 0.0,
251
- s_tmin: float = 0.0,
252
- s_tmax: float = float("inf"),
253
- s_noise: float = 1.0,
254
- generator: Optional[torch.Generator] = None,
255
- return_dict: bool = True,
256
- ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
257
- """
258
- Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
259
- process from the learned model outputs (most often the predicted noise).
260
-
261
- Args:
262
- model_output (`torch.FloatTensor`):
263
- The direct output from learned diffusion model.
264
- timestep (`float`):
265
- The current discrete timestep in the diffusion chain.
266
- sample (`torch.FloatTensor`):
267
- A current instance of a sample created by the diffusion process.
268
- s_churn (`float`):
269
- s_tmin (`float`):
270
- s_tmax (`float`):
271
- s_noise (`float`, defaults to 1.0):
272
- Scaling factor for noise added to the sample.
273
- generator (`torch.Generator`, *optional*):
274
- A random number generator.
275
- return_dict (`bool`):
276
- Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
277
- tuple.
278
-
279
- Returns:
280
- [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
281
- If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
282
- returned, otherwise a tuple is returned where the first element is the sample tensor.
283
- """
284
-
285
- if (
286
- isinstance(timestep, int)
287
- or isinstance(timestep, torch.IntTensor)
288
- or isinstance(timestep, torch.LongTensor)
289
- ):
290
- raise ValueError(
291
- (
292
- "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
293
- " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
294
- " one of the `scheduler.timesteps` as a timestep."
295
- ),
296
- )
297
-
298
- if self.step_index is None:
299
- self._init_step_index(timestep)
300
-
301
- # Upcast to avoid precision issues when computing prev_sample
302
- sample = sample.to(torch.float32).to(model_output.device)
303
-
304
- sigma = self.sigmas[self.step_index].to(model_output.device)
305
- sigma_next = self.sigmas[self.step_index + 1].to(model_output.device)
306
-
307
- prev_sample = sample + (sigma_next - sigma) * model_output
308
-
309
- # Cast sample back to model compatible dtype
310
- prev_sample = prev_sample.to(model_output.dtype)
311
-
312
- # upon completion increase step index by one
313
- self._step_index += 1
314
-
315
- if not return_dict:
316
- return (prev_sample,)
317
-
318
- return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
319
-
320
- def __len__(self):
321
- return self.config.num_train_timesteps
322
-
323
-
324
- @dataclass
325
- class ConsistencyFlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
326
- prev_sample: torch.FloatTensor
327
- pred_original_sample: torch.FloatTensor
328
-
329
-
330
- class ConsistencyFlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
331
- _compatibles = []
332
- order = 1
333
-
334
- @register_to_config
335
- def __init__(
336
- self,
337
- num_train_timesteps: int = 1000,
338
- pcm_timesteps: int = 50,
339
- ):
340
- sigmas = np.linspace(0, 1, num_train_timesteps)
341
- step_ratio = num_train_timesteps // pcm_timesteps
342
-
343
- euler_timesteps = (np.arange(1, pcm_timesteps) * step_ratio).round().astype(np.int64) - 1
344
- euler_timesteps = np.asarray([0] + euler_timesteps.tolist())
345
-
346
- self.euler_timesteps = euler_timesteps
347
- self.sigmas = sigmas[self.euler_timesteps]
348
- self.sigmas = torch.from_numpy((self.sigmas.copy())).to(dtype=torch.float32)
349
- self.timesteps = self.sigmas * num_train_timesteps
350
- self._step_index = None
351
- self._begin_index = None
352
- self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
353
-
354
- @property
355
- def step_index(self):
356
- """
357
- The index counter for current timestep. It will increase 1 after each scheduler step.
358
- """
359
- return self._step_index
360
-
361
- @property
362
- def begin_index(self):
363
- """
364
- The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
365
- """
366
- return self._begin_index
367
-
368
- # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
369
- def set_begin_index(self, begin_index: int = 0):
370
- """
371
- Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
372
-
373
- Args:
374
- begin_index (`int`):
375
- The begin index for the scheduler.
376
- """
377
- self._begin_index = begin_index
378
-
379
- def _sigma_to_t(self, sigma):
380
- return sigma * self.config.num_train_timesteps
381
-
382
- def set_timesteps(
383
- self,
384
- num_inference_steps: int = None,
385
- device: Union[str, torch.device] = None,
386
- sigmas: Optional[List[float]] = None,
387
- ):
388
- """
389
- Sets the discrete timesteps used for the diffusion chain (to be run before inference).
390
-
391
- Args:
392
- num_inference_steps (`int`):
393
- The number of diffusion steps used when generating samples with a pre-trained model.
394
- device (`str` or `torch.device`, *optional*):
395
- The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
396
- """
397
- self.num_inference_steps = num_inference_steps if num_inference_steps is not None else len(sigmas)
398
- inference_indices = np.linspace(
399
- 0, self.config.pcm_timesteps, num=self.num_inference_steps, endpoint=False
400
- )
401
- inference_indices = np.floor(inference_indices).astype(np.int64)
402
- inference_indices = torch.from_numpy(inference_indices).long()
403
-
404
- self.sigmas_ = self.sigmas[inference_indices]
405
- timesteps = self.sigmas_ * self.config.num_train_timesteps
406
- self.timesteps = timesteps.to(device=device)
407
- self.sigmas_ = torch.cat(
408
- [self.sigmas_, torch.ones(1, device=self.sigmas_.device)]
409
- )
410
-
411
- self._step_index = None
412
- self._begin_index = None
413
-
414
- def index_for_timestep(self, timestep, schedule_timesteps=None):
415
- if schedule_timesteps is None:
416
- schedule_timesteps = self.timesteps
417
-
418
- indices = (schedule_timesteps == timestep).nonzero()
419
-
420
- # The sigma index that is taken for the **very** first `step`
421
- # is always the second index (or the last index if there is only 1)
422
- # This way we can ensure we don't accidentally skip a sigma in
423
- # case we start in the middle of the denoising schedule (e.g. for image-to-image)
424
- pos = 1 if len(indices) > 1 else 0
425
-
426
- return indices[pos].item()
427
-
428
- def _init_step_index(self, timestep):
429
- if self.begin_index is None:
430
- if isinstance(timestep, torch.Tensor):
431
- timestep = timestep.to(self.timesteps.device)
432
- self._step_index = self.index_for_timestep(timestep)
433
- else:
434
- self._step_index = self._begin_index
435
-
436
- def step(
437
- self,
438
- model_output: torch.FloatTensor,
439
- timestep: Union[float, torch.FloatTensor],
440
- sample: torch.FloatTensor,
441
- generator: Optional[torch.Generator] = None,
442
- return_dict: bool = True,
443
- ) -> Union[ConsistencyFlowMatchEulerDiscreteSchedulerOutput, Tuple]:
444
- if (
445
- isinstance(timestep, int)
446
- or isinstance(timestep, torch.IntTensor)
447
- or isinstance(timestep, torch.LongTensor)
448
- ):
449
- raise ValueError(
450
- (
451
- "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
452
- " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
453
- " one of the `scheduler.timesteps` as a timestep."
454
- ),
455
- )
456
-
457
- if self.step_index is None:
458
- self._init_step_index(timestep)
459
-
460
- sample = sample.to(torch.float32).to(model_output.device)
461
-
462
- sigma = self.sigmas_[self.step_index].to(model_output.device)
463
- sigma_next = self.sigmas_[self.step_index + 1].to(model_output.device)
464
-
465
- prev_sample = sample + (sigma_next - sigma) * model_output
466
- prev_sample = prev_sample.to(model_output.dtype)
467
-
468
- pred_original_sample = sample + (1.0 - sigma) * model_output
469
- pred_original_sample = pred_original_sample.to(model_output.dtype)
470
-
471
- self._step_index += 1
472
-
473
- if not return_dict:
474
- return (prev_sample,)
475
-
476
- return ConsistencyFlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample,
477
- pred_original_sample=pred_original_sample)
478
-
479
- def __len__(self):
480
- return self.config.num_train_timesteps
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ultrashape/surface_loaders.py DELETED
@@ -1,233 +0,0 @@
1
- # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
2
- # except for the third-party components listed below.
3
- # Hunyuan 3D does not impose any additional limitations beyond what is outlined
4
- # in the repsective licenses of these third-party components.
5
- # Users must comply with all terms and conditions of original licenses of these third-party
6
- # components and must ensure that the usage of the third party components adheres to
7
- # all relevant laws and regulations.
8
-
9
- # For avoidance of doubts, Hunyuan 3D means the large language models and
10
- # their software and algorithms, including trained model weights, parameters (including
11
- # optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
12
- # fine-tuning enabling code and other elements of the foregoing made publicly available
13
- # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
14
-
15
- import numpy as np
16
- import torch
17
- import trimesh
18
-
19
-
20
- def normalize_mesh(mesh, scale=0.9999):
21
- """
22
- Normalize the mesh to fit inside a centered cube with a specified scale.
23
-
24
- The mesh is translated so that its bounding box center is at the origin,
25
- then uniformly scaled so that the longest side of the bounding box fits within [-scale, scale].
26
-
27
- Args:
28
- mesh (trimesh.Trimesh): Input mesh to normalize.
29
- scale (float, optional): Scaling factor to slightly shrink the mesh inside the unit cube. Default is 0.9999.
30
-
31
- Returns:
32
- trimesh.Trimesh: The normalized mesh with applied translation and scaling.
33
- """
34
- bbox = mesh.bounds
35
- center = (bbox[1] + bbox[0]) / 2
36
- scale_ = (bbox[1] - bbox[0]).max()
37
-
38
- mesh.apply_translation(-center)
39
- mesh.apply_scale(1 / scale_ * 2 * scale)
40
-
41
- return mesh
42
-
43
-
44
- def sample_pointcloud(mesh, num=200000):
45
- """
46
- Sample points uniformly from the surface of the mesh along with their corresponding face normals.
47
-
48
- Args:
49
- mesh (trimesh.Trimesh): Input mesh to sample from.
50
- num (int, optional): Number of points to sample. Default is 200000.
51
-
52
- Returns:
53
- Tuple[torch.Tensor, torch.Tensor]:
54
- - points: Sampled points as a float tensor of shape (num, 3).
55
- - normals: Corresponding normals as a float tensor of shape (num, 3).
56
- """
57
- points, face_idx = mesh.sample(num, return_index=True)
58
- normals = mesh.face_normals[face_idx]
59
- points = torch.from_numpy(points.astype(np.float32))
60
- normals = torch.from_numpy(normals.astype(np.float32))
61
- return points, normals
62
-
63
-
64
- def load_surface(mesh, num_points=8192):
65
- """
66
- Normalize the mesh, sample points and normals from its surface, and randomly select a subset.
67
-
68
- Args:
69
- mesh (trimesh.Trimesh): Input mesh to process.
70
- num_points (int, optional): Number of points to randomly select
71
- from the sampled surface points. Default is 8192.
72
-
73
- Returns:
74
- Tuple[torch.Tensor, trimesh.Trimesh]:
75
- - surface: Tensor of shape (1, num_points, 6), concatenating points and normals.
76
- - mesh: The normalized mesh.
77
- """
78
-
79
- mesh = normalize_mesh(mesh, scale=0.98)
80
- surface, normal = sample_pointcloud(mesh)
81
-
82
- rng = np.random.default_rng(seed=0)
83
- ind = rng.choice(surface.shape[0], num_points, replace=False)
84
- surface = torch.FloatTensor(surface[ind])
85
- normal = torch.FloatTensor(normal[ind])
86
-
87
- surface = torch.cat([surface, normal], dim=-1).unsqueeze(0)
88
-
89
- return surface, mesh
90
-
91
-
92
- def sharp_sample_pointcloud(mesh, num=16384):
93
- """
94
- Sample points and normals preferentially from sharp edges of the mesh.
95
-
96
- Sharp edges are detected based on the angle between vertex normals and face normals.
97
- Points are sampled along these edges proportionally to edge length.
98
-
99
- Args:
100
- mesh (trimesh.Trimesh): Input mesh to sample from.
101
- num (int, optional): Number of points to sample from sharp edges. Default is 16384.
102
-
103
- Returns:
104
- Tuple[np.ndarray, np.ndarray]:
105
- - samples: Sampled points along sharp edges, shape (num, 3).
106
- - normals: Corresponding interpolated normals, shape (num, 3).
107
- """
108
- V = mesh.vertices
109
- N = mesh.face_normals
110
- VN = mesh.vertex_normals
111
- F = mesh.faces
112
- VN2 = np.ones(V.shape[0])
113
- for i in range(3):
114
- dot = np.stack((VN2[F[:, i]], np.sum(VN[F[:, i]] * N, axis=-1)), axis=-1)
115
- VN2[F[:, i]] = np.min(dot, axis=-1)
116
-
117
- sharp_mask = VN2 < 0.985
118
- # collect edge
119
- edge_a = np.concatenate((F[:, 0], F[:, 1], F[:, 2]))
120
- edge_b = np.concatenate((F[:, 1], F[:, 2], F[:, 0]))
121
- sharp_edge = ((sharp_mask[edge_a] * sharp_mask[edge_b]))
122
- edge_a = edge_a[sharp_edge > 0]
123
- edge_b = edge_b[sharp_edge > 0]
124
-
125
- sharp_verts_a = V[edge_a]
126
- sharp_verts_b = V[edge_b]
127
- sharp_verts_an = VN[edge_a]
128
- sharp_verts_bn = VN[edge_b]
129
-
130
- weights = np.linalg.norm(sharp_verts_b - sharp_verts_a, axis=-1)
131
- weights /= np.sum(weights)
132
-
133
- random_number = np.random.rand(num)
134
- w = np.random.rand(num, 1)
135
- index = np.searchsorted(weights.cumsum(), random_number)
136
- samples = w * sharp_verts_a[index] + (1 - w) * sharp_verts_b[index]
137
- normals = w * sharp_verts_an[index] + (1 - w) * sharp_verts_bn[index]
138
- return samples, normals
139
-
140
-
141
- def load_surface_sharpegde(mesh, num_points=4096, num_sharp_points=4096, sharpedge_flag=True, normalize_scale=0.9999):
142
- try:
143
- mesh_full = trimesh.util.concatenate(mesh.dump())
144
- except Exception as err:
145
- mesh_full = trimesh.util.concatenate(mesh)
146
- mesh_full = normalize_mesh(mesh_full, scale=normalize_scale)
147
-
148
- origin_num = mesh_full.faces.shape[0]
149
- original_vertices = mesh_full.vertices
150
- original_faces = mesh_full.faces
151
-
152
- mesh = trimesh.Trimesh(vertices=original_vertices, faces=original_faces[:origin_num])
153
- mesh_fill = trimesh.Trimesh(vertices=original_vertices, faces=original_faces[origin_num:])
154
- area = mesh.area
155
- area_fill = mesh_fill.area
156
- sample_num = 819200 // 2 # 499712 // 2
157
- num_fill = int(sample_num * (area_fill / (area + area_fill)))
158
- num = sample_num - num_fill
159
-
160
- random_surface, random_normal = sample_pointcloud(mesh, num=num)
161
- if num_fill == 0:
162
- random_surface_fill, random_normal_fill = np.zeros((0, 3)), np.zeros((0, 3))
163
- else:
164
- random_surface_fill, random_normal_fill = sample_pointcloud(mesh_fill, num=num_fill)
165
- random_sharp_surface, sharp_normal = sharp_sample_pointcloud(mesh, num=sample_num)
166
-
167
- # save_surface
168
- surface = np.concatenate((random_surface, random_normal), axis=1).astype(np.float16)
169
- surface_fill = np.concatenate((random_surface_fill, random_normal_fill), axis=1).astype(np.float16)
170
- sharp_surface = np.concatenate((random_sharp_surface, sharp_normal), axis=1).astype(np.float16)
171
- surface = np.concatenate((surface, surface_fill), axis=0)
172
- if sharpedge_flag:
173
- sharpedge_label = np.zeros((surface.shape[0], 1))
174
- surface = np.concatenate((surface, sharpedge_label), axis=1)
175
- sharpedge_label = np.ones((sharp_surface.shape[0], 1))
176
- sharp_surface = np.concatenate((sharp_surface, sharpedge_label), axis=1)
177
- rng = np.random.default_rng()
178
- ind = rng.choice(surface.shape[0], num_points, replace=False)
179
- surface = torch.FloatTensor(surface[ind])
180
- ind = rng.choice(sharp_surface.shape[0], num_sharp_points, replace=False)
181
- sharp_surface = torch.FloatTensor(sharp_surface[ind])
182
-
183
- return torch.cat([surface, sharp_surface], dim=0).unsqueeze(0), mesh_full
184
-
185
-
186
- class SurfaceLoader:
187
- def __init__(self, num_points=8192):
188
- self.num_points = num_points
189
-
190
- def __call__(self, mesh_or_mesh_path, num_points=None):
191
- if num_points is None:
192
- num_points = self.num_points
193
-
194
- mesh = mesh_or_mesh_path
195
- if isinstance(mesh, str):
196
- mesh = trimesh.load(mesh, force="mesh", merge_primitives=True)
197
- if isinstance(mesh, trimesh.scene.Scene):
198
- for idx, obj in enumerate(mesh.geometry.values()):
199
- if idx == 0:
200
- temp_mesh = obj
201
- else:
202
- temp_mesh = temp_mesh + obj
203
- mesh = temp_mesh
204
- surface, mesh = load_surface(mesh, num_points=num_points)
205
- return surface
206
-
207
-
208
- class SharpEdgeSurfaceLoader:
209
- def __init__(self, num_uniform_points=8192, num_sharp_points=8192, **kwargs):
210
- self.num_uniform_points = num_uniform_points
211
- self.num_sharp_points = num_sharp_points
212
- self.num_points = num_uniform_points + num_sharp_points
213
-
214
- def __call__(self, mesh_or_mesh_path, num_uniform_points=None,
215
- num_sharp_points=None, normalize_scale=0.9999):
216
- if num_uniform_points is None:
217
- num_uniform_points = self.num_uniform_points
218
- if num_sharp_points is None:
219
- num_sharp_points = self.num_sharp_points
220
-
221
- mesh = mesh_or_mesh_path
222
- if isinstance(mesh, str):
223
- mesh = trimesh.load(mesh, force="mesh", merge_primitives=True)
224
- if isinstance(mesh, trimesh.scene.Scene):
225
- for idx, obj in enumerate(mesh.geometry.values()):
226
- if idx == 0:
227
- temp_mesh = obj
228
- else:
229
- temp_mesh = temp_mesh + obj
230
- mesh = temp_mesh
231
- surface, mesh = load_surface_sharpegde(mesh, num_points=num_uniform_points,
232
- num_sharp_points=num_sharp_points, normalize_scale=normalize_scale)
233
- return surface
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ultrashape/utils/__init__.py DELETED
@@ -1,6 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
-
3
- from .misc import get_config_from_file
4
- from .misc import instantiate_from_config
5
- from .utils import get_logger, logger, synchronize_timer, smart_load_model
6
- from .voxelize import voxelize_from_point
 
 
 
 
 
 
 
ultrashape/utils/ema.py DELETED
@@ -1,76 +0,0 @@
1
- import torch
2
- from torch import nn
3
-
4
-
5
- class LitEma(nn.Module):
6
- def __init__(self, model, decay=0.9999, use_num_updates=True):
7
- super().__init__()
8
- if decay < 0.0 or decay > 1.0:
9
- raise ValueError('Decay must be between 0 and 1')
10
-
11
- self.m_name2s_name = {}
12
- self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
13
- self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_updates
14
- else torch.tensor(-1, dtype=torch.int))
15
-
16
- for name, p in model.named_parameters():
17
- if p.requires_grad:
18
- # remove as '.'-character is not allowed in buffers
19
- s_name = name.replace('.', '_____')
20
- self.m_name2s_name.update({name: s_name})
21
- self.register_buffer(s_name, p.clone().detach().data)
22
-
23
- self.collected_params = []
24
-
25
- def forward(self, model):
26
- decay = self.decay
27
-
28
- if self.num_updates >= 0:
29
- self.num_updates += 1
30
- decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
31
-
32
- one_minus_decay = 1.0 - decay
33
-
34
- with torch.no_grad():
35
- m_param = dict(model.named_parameters())
36
- shadow_params = dict(self.named_buffers())
37
-
38
- for key in m_param:
39
- if m_param[key].requires_grad:
40
- sname = self.m_name2s_name[key]
41
- shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
42
- shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
43
- else:
44
- assert not key in self.m_name2s_name
45
-
46
- def copy_to(self, model):
47
- m_param = dict(model.named_parameters())
48
- shadow_params = dict(self.named_buffers())
49
- for key in m_param:
50
- if m_param[key].requires_grad:
51
- m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
52
- else:
53
- assert not key in self.m_name2s_name
54
-
55
- def store(self, model):
56
- """
57
- Save the current parameters for restoring later.
58
- Args:
59
- parameters: Iterable of `torch.nn.Parameter`; the parameters to be
60
- temporarily stored.
61
- """
62
- self.collected_params = [param.clone() for param in model.parameters()]
63
-
64
- def restore(self, model):
65
- """
66
- Restore the parameters stored with the `store` method.
67
- Useful to validate the model with EMA parameters without affecting the
68
- original optimization process. Store the parameters before the
69
- `copy_to` method. After validation (or model saving), use this to
70
- restore the former parameters.
71
- Args:
72
- parameters: Iterable of `torch.nn.Parameter`; the parameters to be
73
- updated with the stored parameters.
74
- """
75
- for c_param, param in zip(self.collected_params, model.parameters()):
76
- param.data.copy_(c_param.data)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ultrashape/utils/misc.py DELETED
@@ -1,200 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
-
3
- import importlib
4
- from omegaconf import OmegaConf, DictConfig, ListConfig
5
-
6
- import torch
7
- import torch.distributed as dist
8
- from typing import Union
9
- from .utils import logger
10
- import os
11
-
12
-
13
- def get_config_from_file(config_file: str) -> Union[DictConfig, ListConfig]:
14
- config_file = OmegaConf.load(config_file)
15
-
16
- if 'base_config' in config_file.keys():
17
- if config_file['base_config'] == "default_base":
18
- base_config = OmegaConf.create()
19
- # base_config = get_default_config()
20
- elif config_file['base_config'].endswith(".yaml"):
21
- base_config = get_config_from_file(config_file['base_config'])
22
- else:
23
- raise ValueError(f"{config_file} must be `.yaml` file or it contains `base_config` key.")
24
-
25
- config_file = {key: value for key, value in config_file if key != "base_config"}
26
-
27
- return OmegaConf.merge(base_config, config_file)
28
-
29
- return config_file
30
-
31
-
32
- def get_obj_from_str(string, reload=False):
33
- module, cls = string.rsplit(".", 1)
34
- if reload:
35
- module_imp = importlib.import_module(module)
36
- importlib.reload(module_imp)
37
- return getattr(importlib.import_module(module, package=None), cls)
38
-
39
-
40
- def get_obj_from_config(config):
41
- if "target" not in config:
42
- raise KeyError("Expected key `target` to instantiate.")
43
-
44
- return get_obj_from_str(config["target"])
45
-
46
-
47
- def instantiate_from_config(config, **kwargs):
48
- if "target" not in config:
49
- raise KeyError("Expected key `target` to instantiate.")
50
-
51
- cls = get_obj_from_str(config["target"])
52
-
53
- if config.get("from_pretrained", None):
54
- return cls.from_pretrained(
55
- config["from_pretrained"],
56
- use_safetensors=config.get('use_safetensors', False),
57
- variant=config.get('variant', 'fp16'))
58
-
59
- params = config.get("params", dict())
60
- # params.update(kwargs)
61
- # instance = cls(**params)
62
- kwargs.update(params)
63
- instance = cls(**kwargs)
64
-
65
- return instance
66
-
67
-
68
- def instantiate_vae_from_config(config, **kwargs):
69
- if "target" not in config:
70
- raise KeyError("Expected key `target` to instantiate.")
71
-
72
- cls = get_obj_from_str(config["target"])
73
-
74
- if config.get("from_pretrained", None):
75
- return cls.from_pretrained(
76
- config["from_pretrained"],
77
- params=config.get("params", dict()),
78
- use_safetensors=config.get('use_safetensors', False),
79
- variant=config.get('variant', 'fp16'))
80
-
81
- params = config.get("params", dict())
82
- kwargs.update(params)
83
- instance = cls(**kwargs)
84
-
85
- return instance
86
-
87
- def instantiate_vae_from_config_local(config, **kwargs):
88
- if "target" not in config:
89
- raise KeyError("Expected key `target` to instantiate.")
90
-
91
- cls = get_obj_from_str(config["target"])
92
-
93
- if not config.get("from_pretrained", None):
94
- raise FileNotFoundError(f"Need from_pretrained!")
95
-
96
- ckpt_path = config["from_pretrained"]
97
-
98
- logger.info(f"Loading model from {ckpt_path}")
99
- if not os.path.exists(ckpt_path):
100
- raise FileNotFoundError(f"Model file {ckpt_path} not found")
101
- ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=True)
102
-
103
- if 'state_dict' not in ckpt:
104
- # deepspeed ckpt
105
- state_dict = {}
106
- for k in ckpt.keys():
107
- new_k = k.replace('vae_model.', '')
108
- state_dict[new_k] = ckpt[k]
109
- else:
110
- state_dict = ckpt["state_dict"]
111
-
112
- params = config.get("params", dict())
113
- kwargs.update(params)
114
- instance = cls(**kwargs)
115
-
116
-
117
- missing, unexpected = instance.load_state_dict(state_dict)
118
- print(f"VAE Missing Keys: {missing}")
119
- print(f"VAE Unexpected Keys: {unexpected}")
120
-
121
- return instance
122
-
123
- def disabled_train(self, mode=True):
124
- """Overwrite model.train with this function to make sure train/eval mode
125
- does not change anymore."""
126
- return self
127
-
128
-
129
- def instantiate_non_trainable_model(config):
130
- model = instantiate_from_config(config)
131
- model = model.eval()
132
- model.train = disabled_train
133
- for param in model.parameters():
134
- param.requires_grad = False
135
-
136
- return model
137
-
138
-
139
- def instantiate_vae_model(config, requires_grad=False):
140
- model = instantiate_vae_from_config(config)
141
- model = model.eval()
142
- model.train = disabled_train
143
- for param in model.parameters():
144
- param.requires_grad = requires_grad
145
-
146
- return model
147
-
148
- def instantiate_vae_model_local(config, requires_grad=False):
149
- model = instantiate_vae_from_config_local(config)
150
- model = model.eval()
151
- model.train = disabled_train
152
- for param in model.parameters():
153
- param.requires_grad = requires_grad
154
-
155
- return model
156
-
157
- def is_dist_avail_and_initialized():
158
- if not dist.is_available():
159
- return False
160
- if not dist.is_initialized():
161
- return False
162
- return True
163
-
164
-
165
- def get_rank():
166
- if not is_dist_avail_and_initialized():
167
- return 0
168
- return dist.get_rank()
169
-
170
-
171
- def get_world_size():
172
- if not is_dist_avail_and_initialized():
173
- return 1
174
- return dist.get_world_size()
175
-
176
-
177
- def all_gather_batch(tensors):
178
- """
179
- Performs all_gather operation on the provided tensors.
180
- """
181
- # Queue the gathered tensors
182
- world_size = get_world_size()
183
- # There is no need for reduction in the single-proc case
184
- if world_size == 1:
185
- return tensors
186
- tensor_list = []
187
- output_tensor = []
188
- for tensor in tensors:
189
- tensor_all = [torch.ones_like(tensor) for _ in range(world_size)]
190
- dist.all_gather(
191
- tensor_all,
192
- tensor,
193
- async_op=False # performance opt
194
- )
195
-
196
- tensor_list.append(tensor_all)
197
-
198
- for tensor_all in tensor_list:
199
- output_tensor.append(torch.cat(tensor_all, dim=0))
200
- return output_tensor
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ultrashape/utils/trainings/__init__.py DELETED
@@ -1 +0,0 @@
1
- # -*- coding: utf-8 -*-
 
 
ultrashape/utils/trainings/callback.py DELETED
@@ -1,213 +0,0 @@
1
- # ------------------------------------------------------------------------------------
2
- # Modified from Taming Transformers (https://github.com/CompVis/taming-transformers)
3
- # Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer. All Rights Reserved.
4
- # ------------------------------------------------------------------------------------
5
-
6
- import os
7
- import time
8
- import wandb
9
- import numpy as np
10
- from PIL import Image
11
- from pathlib import Path
12
- from omegaconf import OmegaConf, DictConfig
13
- from typing import Tuple, Generic, Dict, Callable, Optional, Any
14
- from pprint import pprint
15
-
16
- import torch
17
- import torchvision
18
- import pytorch_lightning as pl
19
- import pytorch_lightning.loggers
20
- from pytorch_lightning.loggers import WandbLogger
21
- from pytorch_lightning.loggers.logger import DummyLogger
22
- from pytorch_lightning.utilities import rank_zero_only, rank_zero_info
23
- from pytorch_lightning.callbacks import Callback
24
-
25
- from functools import wraps
26
-
27
- def node_zero_only(fn: Callable) -> Callable:
28
- @wraps(fn)
29
- def wrapped_fn(*args, **kwargs) -> Optional[Any]:
30
- if node_zero_only.node == 0:
31
- return fn(*args, **kwargs)
32
- return None
33
- return wrapped_fn
34
-
35
- node_zero_only.node = getattr(node_zero_only, 'node', int(os.environ.get('NODE_RANK', 0)))
36
-
37
- def node_zero_experiment(fn: Callable) -> Callable:
38
- """Returns the real experiment on rank 0 and otherwise the DummyExperiment."""
39
- @wraps(fn)
40
- def experiment(self):
41
- @node_zero_only
42
- def get_experiment():
43
- return fn(self)
44
- return get_experiment() or DummyLogger.experiment
45
- return experiment
46
-
47
- # customize wandb for node 0 only
48
- class MyWandbLogger(WandbLogger):
49
- @WandbLogger.experiment.getter
50
- @node_zero_experiment
51
- def experiment(self):
52
- return super().experiment
53
-
54
- class SetupCallback(Callback):
55
- def __init__(self, config: DictConfig, exp_config: DictConfig,
56
- basedir: Path, logdir: str = "log", ckptdir: str = "ckpt") -> None:
57
- super().__init__()
58
- self.logdir = basedir / logdir
59
- self.ckptdir = basedir / ckptdir
60
- self.config = config
61
- self.exp_config = exp_config
62
-
63
- # def on_pretrain_routine_start(self, trainer: pl.trainer.Trainer, pl_module: pl.LightningModule) -> None:
64
- # if trainer.global_rank == 0:
65
- # # Create logdirs and save configs
66
- # os.makedirs(self.logdir, exist_ok=True)
67
- # os.makedirs(self.ckptdir, exist_ok=True)
68
- #
69
- # print("Experiment config")
70
- # print(self.exp_config.pretty())
71
- #
72
- # print("Model config")
73
- # print(self.config.pretty())
74
-
75
- def on_fit_start(self, trainer: pl.trainer.Trainer, pl_module: pl.LightningModule) -> None:
76
- if trainer.global_rank == 0:
77
- # Create logdirs and save configs
78
- os.makedirs(self.logdir, exist_ok=True)
79
- os.makedirs(self.ckptdir, exist_ok=True)
80
-
81
- # print("Experiment config")
82
- # pprint(self.exp_config)
83
- #
84
- # print("Model config")
85
- # pprint(self.config)
86
-
87
-
88
- class ImageLogger(Callback):
89
- def __init__(self, batch_frequency: int, max_images: int, clamp: bool = True,
90
- increase_log_steps: bool = True) -> None:
91
-
92
- super().__init__()
93
- self.batch_freq = batch_frequency
94
- self.max_images = max_images
95
- self.logger_log_images = {
96
- pl.loggers.WandbLogger: self._wandb,
97
- pl.loggers.TestTubeLogger: self._testtube,
98
- }
99
- self.log_steps = [2 ** n for n in range(int(np.log2(self.batch_freq)) + 1)]
100
- if not increase_log_steps:
101
- self.log_steps = [self.batch_freq]
102
- self.clamp = clamp
103
-
104
- @rank_zero_only
105
- def _wandb(self, pl_module, images, batch_idx, split):
106
- # raise ValueError("No way wandb")
107
- grids = dict()
108
- for k in images:
109
- grid = torchvision.utils.make_grid(images[k])
110
- grids[f"{split}/{k}"] = wandb.Image(grid)
111
- pl_module.logger.experiment.log(grids)
112
-
113
- @rank_zero_only
114
- def _testtube(self, pl_module, images, batch_idx, split):
115
- for k in images:
116
- grid = torchvision.utils.make_grid(images[k])
117
- grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
118
-
119
- tag = f"{split}/{k}"
120
- pl_module.logger.experiment.add_image(
121
- tag, grid,
122
- global_step=pl_module.global_step)
123
-
124
- @rank_zero_only
125
- def log_local(self, save_dir: str, split: str, images: Dict,
126
- global_step: int, current_epoch: int, batch_idx: int) -> None:
127
- root = os.path.join(save_dir, "results", split)
128
- os.makedirs(root, exist_ok=True)
129
- for k in images:
130
- grid = torchvision.utils.make_grid(images[k], nrow=4)
131
-
132
- grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
133
- grid = grid.numpy()
134
- grid = (grid * 255).astype(np.uint8)
135
- filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(
136
- k,
137
- global_step,
138
- current_epoch,
139
- batch_idx)
140
- path = os.path.join(root, filename)
141
- os.makedirs(os.path.split(path)[0], exist_ok=True)
142
- Image.fromarray(grid).save(path)
143
-
144
- def log_img(self, pl_module: pl.LightningModule, batch: Tuple[torch.LongTensor, torch.FloatTensor], batch_idx: int,
145
- split: str = "train") -> None:
146
- if (self.check_frequency(batch_idx) and # batch_idx % self.batch_freq == 0
147
- hasattr(pl_module, "log_images") and
148
- callable(pl_module.log_images) and
149
- self.max_images > 0):
150
- logger = type(pl_module.logger)
151
-
152
- is_train = pl_module.training
153
- if is_train:
154
- pl_module.eval()
155
-
156
- with torch.no_grad():
157
- images = pl_module.log_images(batch, split=split, pl_module=pl_module)
158
-
159
- for k in images:
160
- N = min(images[k].shape[0], self.max_images)
161
- images[k] = images[k][:N].detach().cpu()
162
- if self.clamp:
163
- images[k] = images[k].clamp(0, 1)
164
-
165
- self.log_local(pl_module.logger.save_dir, split, images,
166
- pl_module.global_step, pl_module.current_epoch, batch_idx)
167
-
168
- logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None)
169
- logger_log_images(pl_module, images, pl_module.global_step, split)
170
-
171
- if is_train:
172
- pl_module.train()
173
-
174
- def check_frequency(self, batch_idx: int) -> bool:
175
- if (batch_idx % self.batch_freq) == 0 or (batch_idx in self.log_steps):
176
- try:
177
- self.log_steps.pop(0)
178
- except IndexError:
179
- pass
180
- return True
181
- return False
182
-
183
- def on_train_batch_end(self, trainer: pl.trainer.Trainer, pl_module: pl.LightningModule,
184
- outputs: Generic, batch: Tuple[torch.LongTensor, torch.FloatTensor], batch_idx: int) -> None:
185
- self.log_img(pl_module, batch, batch_idx, split="train")
186
-
187
- def on_validation_batch_end(self, trainer: pl.trainer.Trainer, pl_module: pl.LightningModule,
188
- outputs: Generic, batch: Tuple[torch.LongTensor, torch.FloatTensor],
189
- dataloader_idx: int, batch_idx: int) -> None:
190
- self.log_img(pl_module, batch, batch_idx, split="val")
191
-
192
-
193
- class CUDACallback(Callback):
194
- # see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py
195
- def on_train_epoch_start(self, trainer, pl_module):
196
- # Reset the memory use counter
197
- torch.cuda.reset_peak_memory_stats(trainer.root_gpu)
198
- torch.cuda.synchronize(trainer.root_gpu)
199
- self.start_time = time.time()
200
-
201
- def on_train_epoch_end(self, trainer, pl_module, outputs):
202
- torch.cuda.synchronize(trainer.root_gpu)
203
- max_memory = torch.cuda.max_memory_allocated(trainer.root_gpu) / 2 ** 20
204
- epoch_time = time.time() - self.start_time
205
-
206
- try:
207
- max_memory = trainer.training_type_plugin.reduce(max_memory)
208
- epoch_time = trainer.training_type_plugin.reduce(epoch_time)
209
-
210
- rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")
211
- rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB")
212
- except AttributeError:
213
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ultrashape/utils/trainings/lr_scheduler.py DELETED
@@ -1,53 +0,0 @@
1
- # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
2
- # except for the third-party components listed below.
3
- # Hunyuan 3D does not impose any additional limitations beyond what is outlined
4
- # in the repsective licenses of these third-party components.
5
- # Users must comply with all terms and conditions of original licenses of these third-party
6
- # components and must ensure that the usage of the third party components adheres to
7
- # all relevant laws and regulations.
8
-
9
- # For avoidance of doubts, Hunyuan 3D means the large language models and
10
- # their software and algorithms, including trained model weights, parameters (including
11
- # optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
12
- # fine-tuning enabling code and other elements of the foregoing made publicly available
13
- # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
14
-
15
- import numpy as np
16
-
17
-
18
- class BaseScheduler(object):
19
-
20
- def schedule(self, n, **kwargs):
21
- raise NotImplementedError
22
-
23
-
24
- class LambdaWarmUpCosineFactorScheduler(BaseScheduler):
25
- """
26
- note: use with a base_lr of 1.0
27
- """
28
- def __init__(self, warm_up_steps, f_min, f_max, f_start, max_decay_steps, verbosity_interval=0, **ignore_kwargs):
29
- self.lr_warm_up_steps = warm_up_steps
30
- self.f_start = f_start
31
- self.f_min = f_min
32
- self.f_max = f_max
33
- self.lr_max_decay_steps = max_decay_steps
34
- self.last_f = 0.
35
- self.verbosity_interval = verbosity_interval
36
-
37
- def schedule(self, n, **kwargs):
38
- if self.verbosity_interval > 0:
39
- if n % self.verbosity_interval == 0:
40
- print(f"current step: {n}, recent lr-multiplier: {self.f_start}")
41
- if n < self.lr_warm_up_steps:
42
- f = (self.f_max - self.f_start) / self.lr_warm_up_steps * n + self.f_start
43
- self.last_f = f
44
- return f
45
- else:
46
- t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
47
- t = min(t, 1.0)
48
- f = self.f_min + 0.5 * (self.f_max - self.f_min) * (1 + np.cos(t * np.pi))
49
- self.last_f = f
50
- return f
51
-
52
- def __call__(self, n, **kwargs):
53
- return self.schedule(n, **kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ultrashape/utils/trainings/mesh.py DELETED
@@ -1,128 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
-
3
- # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
4
- # except for the third-party components listed below.
5
- # Hunyuan 3D does not impose any additional limitations beyond what is outlined
6
- # in the repsective licenses of these third-party components.
7
- # Users must comply with all terms and conditions of original licenses of these third-party
8
- # components and must ensure that the usage of the third party components adheres to
9
- # all relevant laws and regulations.
10
-
11
- # For avoidance of doubts, Hunyuan 3D means the large language models and
12
- # their software and algorithms, including trained model weights, parameters (including
13
- # optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
14
- # fine-tuning enabling code and other elements of the foregoing made publicly available
15
- # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
16
-
17
- import os
18
- import cv2
19
- import numpy as np
20
- import PIL.Image
21
- from typing import Optional
22
-
23
- import trimesh
24
-
25
-
26
- def save_obj(pointnp_px3, facenp_fx3, fname):
27
- fid = open(fname, "w")
28
- write_str = ""
29
- for pidx, p in enumerate(pointnp_px3):
30
- pp = p
31
- write_str += "v %f %f %f\n" % (pp[0], pp[1], pp[2])
32
-
33
- for i, f in enumerate(facenp_fx3):
34
- f1 = f + 1
35
- write_str += "f %d %d %d\n" % (f1[0], f1[1], f1[2])
36
- fid.write(write_str)
37
- fid.close()
38
- return
39
-
40
-
41
- def savemeshtes2(pointnp_px3, tcoords_px2, facenp_fx3, facetex_fx3, tex_map, fname):
42
- fol, na = os.path.split(fname)
43
- na, _ = os.path.splitext(na)
44
-
45
- matname = "%s/%s.mtl" % (fol, na)
46
- fid = open(matname, "w")
47
- fid.write("newmtl material_0\n")
48
- fid.write("Kd 1 1 1\n")
49
- fid.write("Ka 0 0 0\n")
50
- fid.write("Ks 0.4 0.4 0.4\n")
51
- fid.write("Ns 10\n")
52
- fid.write("illum 2\n")
53
- fid.write("map_Kd %s.png\n" % na)
54
- fid.close()
55
- ####
56
-
57
- fid = open(fname, "w")
58
- fid.write("mtllib %s.mtl\n" % na)
59
-
60
- for pidx, p3 in enumerate(pointnp_px3):
61
- pp = p3
62
- fid.write("v %f %f %f\n" % (pp[0], pp[1], pp[2]))
63
-
64
- for pidx, p2 in enumerate(tcoords_px2):
65
- pp = p2
66
- fid.write("vt %f %f\n" % (pp[0], pp[1]))
67
-
68
- fid.write("usemtl material_0\n")
69
- for i, f in enumerate(facenp_fx3):
70
- f1 = f + 1
71
- f2 = facetex_fx3[i] + 1
72
- fid.write("f %d/%d %d/%d %d/%d\n" % (f1[0], f2[0], f1[1], f2[1], f1[2], f2[2]))
73
- fid.close()
74
-
75
- PIL.Image.fromarray(np.ascontiguousarray(tex_map), "RGB").save(
76
- os.path.join(fol, "%s.png" % na))
77
-
78
- return
79
-
80
-
81
- class MeshOutput(object):
82
-
83
- def __init__(self,
84
- mesh_v: np.ndarray,
85
- mesh_f: np.ndarray,
86
- vertex_colors: Optional[np.ndarray] = None,
87
- uvs: Optional[np.ndarray] = None,
88
- mesh_tex_idx: Optional[np.ndarray] = None,
89
- tex_map: Optional[np.ndarray] = None):
90
-
91
- self.mesh_v = mesh_v
92
- self.mesh_f = mesh_f
93
- self.vertex_colors = vertex_colors
94
- self.uvs = uvs
95
- self.mesh_tex_idx = mesh_tex_idx
96
- self.tex_map = tex_map
97
-
98
- def contain_uv_texture(self):
99
- return (self.uvs is not None) and (self.mesh_tex_idx is not None) and (self.tex_map is not None)
100
-
101
- def contain_vertex_colors(self):
102
- return self.vertex_colors is not None
103
-
104
- def export(self, fname):
105
-
106
- if self.contain_uv_texture():
107
- savemeshtes2(
108
- self.mesh_v,
109
- self.uvs,
110
- self.mesh_f,
111
- self.mesh_tex_idx,
112
- self.tex_map,
113
- fname
114
- )
115
-
116
- elif self.contain_vertex_colors():
117
- mesh_obj = trimesh.Trimesh(vertices=self.mesh_v, faces=self.mesh_f, vertex_colors=self.vertex_colors)
118
- mesh_obj.export(fname)
119
-
120
- else:
121
- save_obj(
122
- self.mesh_v,
123
- self.mesh_f,
124
- fname
125
- )
126
-
127
-
128
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ultrashape/utils/trainings/mesh_log_callback.py DELETED
@@ -1,342 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
-
3
- # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
4
- # except for the third-party components listed below.
5
- # Hunyuan 3D does not impose any additional limitations beyond what is outlined
6
- # in the repsective licenses of these third-party components.
7
- # Users must comply with all terms and conditions of original licenses of these third-party
8
- # components and must ensure that the usage of the third party components adheres to
9
- # all relevant laws and regulations.
10
-
11
- # For avoidance of doubts, Hunyuan 3D means the large language models and
12
- # their software and algorithms, including trained model weights, parameters (including
13
- # optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
14
- # fine-tuning enabling code and other elements of the foregoing made publicly available
15
- # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
16
-
17
- import json
18
- import math
19
- import os
20
- from typing import Tuple, Generic, Dict, List, Union, Optional
21
-
22
- import trimesh
23
- import numpy as np
24
- import pytorch_lightning as pl
25
- import pytorch_lightning.loggers
26
- import torch
27
- import torchvision
28
- from pytorch_lightning.callbacks import Callback
29
- from pytorch_lightning.utilities import rank_zero_only
30
-
31
- from hy3dshape.pipelines import export_to_trimesh
32
- from hy3dshape.utils.trainings.mesh import MeshOutput
33
- from hy3dshape.utils.visualizers import html_util
34
- from hy3dshape.utils.visualizers.pythreejs_viewer import PyThreeJSViewer
35
-
36
-
37
- class ImageConditionalASLDiffuserLogger(Callback):
38
- def __init__(self,
39
- step_frequency: int,
40
- num_samples: int = 1,
41
- mean: Optional[Union[List[float], Tuple[float]]] = None,
42
- std: Optional[Union[List[float], Tuple[float]]] = None,
43
- bounds: Union[List[float], Tuple[float]] = (-1.1, -1.1, -1.1, 1.1, 1.1, 1.1),
44
- **kwargs) -> None:
45
-
46
- super().__init__()
47
- self.bbox_size = np.array(bounds[3:6]) - np.array(bounds[0:3])
48
-
49
- if mean is not None:
50
- mean = np.asarray(mean)
51
-
52
- if std is not None:
53
- std = np.asarray(std)
54
-
55
- self.mean = mean
56
- self.std = std
57
-
58
- self.step_freq = step_frequency
59
- self.num_samples = num_samples
60
- self.has_train_logged = False
61
- self.logger_log_images = {
62
- pl.loggers.WandbLogger: self._wandb,
63
- }
64
-
65
- self.viewer = PyThreeJSViewer(settings={}, render_mode="WEBSITE")
66
-
67
- @rank_zero_only
68
- def _wandb(self, pl_module, images, batch_idx, split):
69
- # raise ValueError("No way wandb")
70
- grids = dict()
71
- for k in images:
72
- grid = torchvision.utils.make_grid(images[k])
73
- grids[f"{split}/{k}"] = wandb.Image(grid)
74
- pl_module.logger.experiment.log(grids)
75
-
76
- def log_local(self,
77
- outputs: List[List['Latent2MeshOutput']],
78
- images: Union[np.ndarray, List[np.ndarray]],
79
- description: List[str],
80
- keys: List[str],
81
- save_dir: str, split: str,
82
- global_step: int, current_epoch: int, batch_idx: int,
83
- prog_bar: bool = False,
84
- multi_views=None, # yf ...
85
- ) -> None:
86
-
87
- folder = "gs-{:010}_e-{:06}_b-{:06}".format(global_step, current_epoch, batch_idx)
88
- visual_dir = os.path.join(save_dir, "visuals", split, folder)
89
- os.makedirs(visual_dir, exist_ok=True)
90
-
91
- num_samples = len(images)
92
-
93
- for i in range(num_samples):
94
- key_i = keys[i]
95
- image_i = self.denormalize_image(images[i])
96
- shape_tag_i = description[i]
97
-
98
- for j in range(1):
99
- mesh = outputs[j][i]
100
- if mesh is None:
101
- continue
102
-
103
- mesh_v = mesh.mesh_v.copy()
104
- mesh_v[:, 0] += j * np.max(self.bbox_size)
105
- self.viewer.add_mesh(mesh_v, mesh.mesh_f)
106
-
107
- image_tag = html_util.to_image_embed_tag(image_i)
108
- mesh_tag = self.viewer.to_html(html_frame=False)
109
-
110
- table_tag = f"""
111
- <table border = "1">
112
- <caption> {shape_tag_i} - {key_i} </caption>
113
- <caption> Input Image | Generated Mesh </caption>
114
- <tr>
115
- <td>{image_tag}</td>
116
- <td>{mesh_tag}</td>
117
- </tr>
118
- </table>
119
- """
120
-
121
- if multi_views is not None:
122
- multi_views_i = self.make_grid(multi_views[i])
123
- views_tag = html_util.to_image_embed_tag(self.denormalize_image(multi_views_i))
124
- table_tag = f"""
125
- <table border = "1">
126
- <caption> {shape_tag_i} - {key_i} </caption>
127
- <caption> Input Image | Generated Mesh </caption>
128
- <tr>
129
- <td>{image_tag}</td>
130
- <td>{views_tag}</td>
131
- <td>{mesh_tag}</td>
132
- </tr>
133
- </table>
134
- """
135
-
136
- html_frame = html_util.to_html_frame(table_tag)
137
- if len(key_i) > 100:
138
- key_i = key_i[:100]
139
- with open(os.path.join(visual_dir, f"{key_i}.html"), "w") as writer:
140
- writer.write(html_frame)
141
-
142
- self.viewer.reset()
143
-
144
- def log_sample(self,
145
- pl_module: pl.LightningModule,
146
- batch: Dict[str, torch.FloatTensor],
147
- batch_idx: int,
148
- split: str = "train") -> None:
149
- """
150
-
151
- Args:
152
- pl_module:
153
- batch (dict): the batch sample information, and it contains:
154
- - surface (torch.FloatTensor):
155
- - image (torch.FloatTensor):
156
- batch_idx (int):
157
- split (str):
158
-
159
- Returns:
160
-
161
- """
162
-
163
- is_train = pl_module.training
164
- if is_train:
165
- pl_module.eval()
166
-
167
- batch_size = len(batch["surface"])
168
- replace = batch_size < self.num_samples
169
- ids = np.random.choice(batch_size, self.num_samples, replace=replace)
170
-
171
- with torch.no_grad():
172
- # run text to mesh
173
- # keys = [batch["__key__"][i] for i in ids]
174
- keys = [f'key_{i}' for i in ids]
175
- # texts = [batch["text"][i] for i in ids]
176
- texts = [f'text_{i}'for i in ids]
177
- # description = [batch["description"][i] for i in ids]
178
- description = [f'desc_{i}_{os.path.splitext(os.path.basename(batch["uid"][i]))[0]}' for i in ids]
179
- images = batch["image"][ids]
180
- mask_input = batch["mask"][ids] if 'mask' in batch else None
181
- # uids = batch["uid"][ids]
182
- sample_batch = {
183
- "__key__": keys,
184
- "image": images,
185
- 'text': texts,
186
- 'mask': mask_input,
187
- }
188
-
189
- # if 'cam_parm' in batch:
190
- # sample_batch['cam_parm'] = batch['cam_parm'][ids]
191
-
192
- # if 'multi_views' in batch: # yf ...
193
- # sample_batch['multi_views'] = batch['multi_views'][ids]
194
-
195
- outputs = pl_module.sample(
196
- batch=sample_batch,
197
- output_type='latents2mesh'
198
- )
199
-
200
- images = images.cpu().float().numpy()
201
- # images = self.denormalize_image(images)
202
- # images = np.transpose(images, (0, 2, 3, 1))
203
- # images = ((images + 1) / 2 * 255).astype(np.uint8)
204
-
205
- self.log_local(outputs, images, description, keys, pl_module.logger.save_dir, split,
206
- pl_module.global_step, pl_module.current_epoch, batch_idx, prog_bar=False,
207
- multi_views=sample_batch.get('multi_views'))
208
-
209
- if is_train: pl_module.train()
210
-
211
- def make_grid(self, images): # return (3,h,w) in (0,1) ...
212
- images_resized = []
213
- for img in images:
214
- img_resized = torchvision.transforms.functional.resize(img, (320, 320))
215
- images_resized.append(img_resized)
216
- image = torchvision.utils.make_grid(images_resized, nrow=2, padding=5, pad_value=255)
217
-
218
- image = image.cpu().numpy()
219
- # image = np.transpose(image, (1, 2, 0))
220
- # image = (image * 255).astype(np.uint8)
221
-
222
- return image
223
-
224
- def check_frequency(self, step: int) -> bool:
225
- if step % self.step_freq == 0:
226
- return True
227
- return False
228
-
229
- def on_train_batch_end(self, trainer: pl.trainer.Trainer, pl_module: pl.LightningModule,
230
- outputs: Generic, batch: Dict[str, torch.FloatTensor], batch_idx: int) -> None:
231
-
232
- if (self.check_frequency(pl_module.global_step) and # batch_idx % self.batch_freq == 0
233
- hasattr(pl_module, "sample") and
234
- callable(pl_module.sample) and
235
- self.num_samples > 0):
236
- self.log_sample(pl_module, batch, batch_idx, split="train")
237
- self.has_train_logged = True
238
-
239
- def on_validation_batch_end(self, trainer: pl.trainer.Trainer, pl_module: pl.LightningModule,
240
- outputs: Generic, batch: Dict[str, torch.FloatTensor],
241
- dataloader_idx: int, batch_idx: int) -> None:
242
-
243
- if self.has_train_logged:
244
- self.log_sample(pl_module, batch, batch_idx, split="val")
245
- self.has_train_logged = False
246
-
247
- def denormalize_image(self, image):
248
- """
249
-
250
- Args:
251
- image (np.ndarray): [3, h, w]
252
-
253
- Returns:
254
- image (np.ndarray): [h, w, 3], np.uint8, [0, 255].
255
- """
256
- # image = np.transpose(image, (0, 2, 3, 1))
257
- image = np.transpose(image, (1, 2, 0))
258
-
259
- if self.std is not None:
260
- image = image * self.std
261
-
262
- if self.mean is not None:
263
- image = image + self.mean
264
-
265
- image = (image * 255).astype(np.uint8)
266
-
267
- return image
268
-
269
-
270
- class ImageConditionalFixASLDiffuserLogger(Callback):
271
- def __init__(
272
- self,
273
- step_frequency: int,
274
- test_data_path: str,
275
- max_size: int = None,
276
- save_dir: str = 'infer',
277
- **kwargs,
278
- ) -> None:
279
- super().__init__()
280
- self.step_freq = step_frequency
281
- self.viewer = PyThreeJSViewer(settings={}, render_mode="WEBSITE")
282
-
283
- self.test_data_path = test_data_path
284
- with open(self.test_data_path, 'r') as f:
285
- data = json.load(f)
286
- self.file_list = data['file_list']
287
- # self.file_folder = data['file_folder']
288
- if max_size is not None:
289
- self.file_list = self.file_list[:max_size]
290
- self.kwargs = kwargs
291
- self.save_dir = save_dir
292
-
293
- def on_train_batch_end(
294
- self,
295
- trainer: pl.trainer.Trainer,
296
- pl_module: pl.LightningModule,
297
- outputs: Generic,
298
- batch: Dict[str, torch.FloatTensor],
299
- batch_idx: int,
300
- ):
301
- if pl_module.global_step % self.step_freq == 0:
302
- with open(self.test_data_path, 'r') as f:
303
- data = json.load(f)
304
- self.file_list = data['file_list']
305
- is_train = pl_module.training
306
- if is_train:
307
- pl_module.eval()
308
-
309
- # folder_path = self.file_folder
310
- # folder_name = os.path.basename(folder_path)
311
- folder = "gs-{:010}_e-{:06}_b-{:06}".format(pl_module.global_step, pl_module.current_epoch, batch_idx)
312
- visual_dir = os.path.join(pl_module.logger.save_dir, self.save_dir, folder)
313
- os.makedirs(visual_dir, exist_ok=True)
314
-
315
- image_paths = self.file_list
316
- chunk_size = math.ceil(len(image_paths) / trainer.world_size)
317
- if pl_module.global_rank == trainer.world_size - 1:
318
- image_paths = image_paths[pl_module.global_rank * chunk_size:]
319
- else:
320
- image_paths = image_paths[pl_module.global_rank * chunk_size:(pl_module.global_rank + 1) * chunk_size]
321
-
322
- print(f'Rank{pl_module.global_rank}: processing {len(image_paths)}|{len(self.file_list)} images')
323
- for image_path in image_paths:
324
- # if folder_path in image_path:
325
- # save_path = image_path.replace(folder_path, visual_dir)
326
- # else:
327
- save_path = os.path.join(visual_dir, os.path.basename(image_path))
328
- save_path = os.path.splitext(save_path)[0] + '.glb'
329
-
330
- if isinstance(image_path, str):
331
- print(image_path)
332
-
333
- with torch.no_grad():
334
- mesh = pl_module.sample(batch={"image": image_path}, **self.kwargs)[0][0]
335
- if isinstance(mesh, tuple) and len(mesh)==2:
336
- mesh = export_to_trimesh(mesh)
337
- elif isinstance(mesh, trimesh.Trimesh):
338
- os.makedirs(os.path.dirname(save_path), exist_ok=True)
339
- mesh.export(save_path)
340
-
341
- if is_train:
342
- pl_module.train()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ultrashape/utils/trainings/peft.py DELETED
@@ -1,78 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
-
3
- # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
4
- # except for the third-party components listed below.
5
- # Hunyuan 3D does not impose any additional limitations beyond what is outlined
6
- # in the repsective licenses of these third-party components.
7
- # Users must comply with all terms and conditions of original licenses of these third-party
8
- # components and must ensure that the usage of the third party components adheres to
9
- # all relevant laws and regulations.
10
-
11
- # For avoidance of doubts, Hunyuan 3D means the large language models and
12
- # their software and algorithms, including trained model weights, parameters (including
13
- # optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
14
- # fine-tuning enabling code and other elements of the foregoing made publicly available
15
- # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
16
-
17
- import os
18
- from pytorch_lightning.callbacks import Callback
19
- from omegaconf import OmegaConf, ListConfig
20
-
21
- class PeftSaveCallback(Callback):
22
- def __init__(self, peft_model, save_dir: str, save_every_n_steps: int = None):
23
- super().__init__()
24
- self.peft_model = peft_model
25
- self.save_dir = save_dir
26
- self.save_every_n_steps = save_every_n_steps
27
- os.makedirs(self.save_dir, exist_ok=True)
28
-
29
- def recursive_convert(self, obj):
30
- from omegaconf import OmegaConf, ListConfig
31
- if isinstance(obj, (OmegaConf, ListConfig)):
32
- return OmegaConf.to_container(obj, resolve=True)
33
- elif isinstance(obj, dict):
34
- return {k: self.recursive_convert(v) for k, v in obj.items()}
35
- elif isinstance(obj, list):
36
- return [self.recursive_convert(i) for i in obj]
37
- elif isinstance(obj, type):
38
- # 避免修改类对象
39
- return obj
40
- elif hasattr(obj, '__dict__'):
41
- for attr_name, attr_value in vars(obj).items():
42
- setattr(obj, attr_name, self.recursive_convert(attr_value))
43
- return obj
44
- else:
45
- return obj
46
-
47
- # def recursive_convert(self, obj):
48
- # if isinstance(obj, (OmegaConf, ListConfig)):
49
- # return OmegaConf.to_container(obj, resolve=True)
50
- # elif isinstance(obj, dict):
51
- # return {k: self.recursive_convert(v) for k, v in obj.items()}
52
- # elif isinstance(obj, list):
53
- # return [self.recursive_convert(i) for i in obj]
54
- # elif hasattr(obj, '__dict__'):
55
- # for attr_name, attr_value in vars(obj).items():
56
- # setattr(obj, attr_name, self.recursive_convert(attr_value))
57
- # return obj
58
- # else:
59
- # return obj
60
-
61
- def _convert_peft_config(self):
62
- pc = self.peft_model.peft_config
63
- self.peft_model.peft_config = self.recursive_convert(pc)
64
-
65
- def on_train_epoch_end(self, trainer, pl_module):
66
- self._convert_peft_config()
67
- save_path = os.path.join(self.save_dir, f"epoch_{trainer.current_epoch}")
68
- self.peft_model.save_pretrained(save_path)
69
- print(f"[PeftSaveCallback] Saved LoRA weights to {save_path}")
70
-
71
- def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
72
- if self.save_every_n_steps is not None:
73
- global_step = trainer.global_step
74
- if global_step % self.save_every_n_steps == 0 and global_step > 0:
75
- self._convert_peft_config()
76
- save_path = os.path.join(self.save_dir, f"step_{global_step}")
77
- self.peft_model.save_pretrained(save_path)
78
- print(f"[PeftSaveCallback] Saved LoRA weights to {save_path}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ultrashape/utils/typing.py DELETED
@@ -1,41 +0,0 @@
1
- """
2
- This module contains type annotations for the project, using
3
- 1. Python type hints (https://docs.python.org/3/library/typing.html) for Python objects
4
- 2. jaxtyping (https://github.com/google/jaxtyping/blob/main/API.md) for PyTorch tensors
5
-
6
- Two types of typing checking can be used:
7
- 1. Static type checking with mypy (install with pip and enabled as the default linter in VSCode)
8
- 2. Runtime type checking with typeguard (install with pip and triggered at runtime, mainly for tensor dtype and shape checking)
9
- """
10
-
11
- # Basic types
12
- from typing import (
13
- Any,
14
- Callable,
15
- Dict,
16
- Iterable,
17
- List,
18
- Literal,
19
- NamedTuple,
20
- NewType,
21
- Optional,
22
- Sized,
23
- Tuple,
24
- Type,
25
- TypeVar,
26
- Union,
27
- Sequence,
28
- )
29
-
30
- # Tensor dtype
31
- # for jaxtyping usage, see https://github.com/google/jaxtyping/blob/main/API.md
32
- from jaxtyping import Bool, Complex, Float, Inexact, Int, Integer, Num, Shaped, UInt
33
-
34
- # Config type
35
- from omegaconf import DictConfig
36
-
37
- # PyTorch Tensor type
38
- from torch import Tensor
39
-
40
- # Runtime type checking decorator
41
- from typeguard import typechecked as typechecker
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ultrashape/utils/utils.py DELETED
@@ -1,128 +0,0 @@
1
- # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
2
- # except for the third-party components listed below.
3
- # Hunyuan 3D does not impose any additional limitations beyond what is outlined
4
- # in the repsective licenses of these third-party components.
5
- # Users must comply with all terms and conditions of original licenses of these third-party
6
- # components and must ensure that the usage of the third party components adheres to
7
- # all relevant laws and regulations.
8
-
9
- # For avoidance of doubts, Hunyuan 3D means the large language models and
10
- # their software and algorithms, including trained model weights, parameters (including
11
- # optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
12
- # fine-tuning enabling code and other elements of the foregoing made publicly available
13
- # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
14
-
15
- import logging
16
- import os
17
- from functools import wraps
18
-
19
- import torch
20
-
21
-
22
- def get_logger(name):
23
- logger = logging.getLogger(name)
24
- logger.setLevel(logging.INFO)
25
-
26
- console_handler = logging.StreamHandler()
27
- console_handler.setLevel(logging.INFO)
28
-
29
- formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
30
- console_handler.setFormatter(formatter)
31
- logger.addHandler(console_handler)
32
- return logger
33
-
34
-
35
- logger = get_logger('hy3dgen.shapgen')
36
-
37
-
38
- class synchronize_timer:
39
- """ Synchronized timer to count the inference time of `nn.Module.forward`.
40
-
41
- Supports both context manager and decorator usage.
42
-
43
- Example as context manager:
44
- ```python
45
- with synchronize_timer('name') as t:
46
- run()
47
- ```
48
-
49
- Example as decorator:
50
- ```python
51
- @synchronize_timer('Export to trimesh')
52
- def export_to_trimesh(mesh_output):
53
- pass
54
- ```
55
- """
56
-
57
- def __init__(self, name=None):
58
- self.name = name
59
-
60
- def __enter__(self):
61
- """Context manager entry: start timing."""
62
- if os.environ.get('HY3DGEN_DEBUG', '0') == '1':
63
- self.start = torch.cuda.Event(enable_timing=True)
64
- self.end = torch.cuda.Event(enable_timing=True)
65
- self.start.record()
66
- return lambda: self.time
67
-
68
- def __exit__(self, exc_type, exc_value, exc_tb):
69
- """Context manager exit: stop timing and log results."""
70
- if os.environ.get('HY3DGEN_DEBUG', '0') == '1':
71
- self.end.record()
72
- torch.cuda.synchronize()
73
- self.time = self.start.elapsed_time(self.end)
74
- if self.name is not None:
75
- logger.info(f'{self.name} takes {self.time} ms')
76
-
77
- def __call__(self, func):
78
- """Decorator: wrap the function to time its execution."""
79
-
80
- @wraps(func)
81
- def wrapper(*args, **kwargs):
82
- with self:
83
- result = func(*args, **kwargs)
84
- return result
85
-
86
- return wrapper
87
-
88
-
89
- def smart_load_model(
90
- model_path,
91
- subfolder,
92
- use_safetensors,
93
- variant,
94
- ):
95
- original_model_path = model_path
96
- # try local path
97
- base_dir = os.environ.get('HY3DGEN_MODELS', '~/.cache/hy3dgen')
98
- model_fld = os.path.expanduser(os.path.join(base_dir, model_path))
99
- model_path = os.path.expanduser(os.path.join(base_dir, model_path, subfolder))
100
- logger.info(f'Try to load model from local path: {model_path}')
101
- if not os.path.exists(model_path):
102
- logger.info('Model path not exists, try to download from huggingface')
103
- try:
104
- from huggingface_hub import snapshot_download
105
- # 只下载指定子目录
106
- path = snapshot_download(
107
- repo_id=original_model_path,
108
- allow_patterns=[f"{subfolder}/*"], # 关键修改:模式匹配子文件夹
109
- local_dir=model_fld
110
- )
111
- model_path = os.path.join(path, subfolder) # 保持路径拼接逻辑不变
112
- except ImportError:
113
- logger.warning(
114
- "You need to install HuggingFace Hub to load models from the hub."
115
- )
116
- raise RuntimeError(f"Model path {model_path} not found")
117
- except Exception as e:
118
- raise e
119
-
120
- if not os.path.exists(model_path):
121
- raise FileNotFoundError(f"Model path {original_model_path} not found")
122
-
123
- extension = 'ckpt' if not use_safetensors else 'safetensors'
124
- variant = '' if variant is None else f'.{variant}'
125
- ckpt_name = f'model{variant}.{extension}'
126
- config_path = os.path.join(model_path, 'config.yaml')
127
- ckpt_path = os.path.join(model_path, ckpt_name)
128
- return config_path, ckpt_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ultrashape/utils/visualizers/__init__.py DELETED
@@ -1 +0,0 @@
1
- # -*- coding: utf-8 -*-
 
 
ultrashape/utils/visualizers/color_util.py DELETED
@@ -1,57 +0,0 @@
1
- # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
2
- # except for the third-party components listed below.
3
- # Hunyuan 3D does not impose any additional limitations beyond what is outlined
4
- # in the repsective licenses of these third-party components.
5
- # Users must comply with all terms and conditions of original licenses of these third-party
6
- # components and must ensure that the usage of the third party components adheres to
7
- # all relevant laws and regulations.
8
-
9
- # For avoidance of doubts, Hunyuan 3D means the large language models and
10
- # their software and algorithms, including trained model weights, parameters (including
11
- # optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
12
- # fine-tuning enabling code and other elements of the foregoing made publicly available
13
- # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
14
-
15
- import numpy as np
16
- import matplotlib.pyplot as plt
17
-
18
-
19
- # Helper functions
20
- def get_colors(inp, colormap="viridis", normalize=True, vmin=None, vmax=None):
21
- colormap = plt.cm.get_cmap(colormap)
22
- if normalize:
23
- vmin = np.min(inp)
24
- vmax = np.max(inp)
25
-
26
- norm = plt.Normalize(vmin, vmax)
27
- return colormap(norm(inp))[:, :3]
28
-
29
-
30
- def gen_checkers(n_checkers_x, n_checkers_y, width=256, height=256):
31
- # tex dims need to be power of two.
32
- array = np.ones((width, height, 3), dtype='float32')
33
-
34
- # width in texels of each checker
35
- checker_w = width / n_checkers_x
36
- checker_h = height / n_checkers_y
37
-
38
- for y in range(height):
39
- for x in range(width):
40
- color_key = int(x / checker_w) + int(y / checker_h)
41
- if color_key % 2 == 0:
42
- array[x, y, :] = [1., 0.874, 0.0]
43
- else:
44
- array[x, y, :] = [0., 0., 0.]
45
- return array
46
-
47
-
48
- def gen_circle(width=256, height=256):
49
- xx, yy = np.mgrid[:width, :height]
50
- circle = (xx - width / 2 + 0.5) ** 2 + (yy - height / 2 + 0.5) ** 2
51
- array = np.ones((width, height, 4), dtype='float32')
52
- array[:, :, 0] = (circle <= width)
53
- array[:, :, 1] = (circle <= width)
54
- array[:, :, 2] = (circle <= width)
55
- array[:, :, 3] = circle <= width
56
- return array
57
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ultrashape/utils/visualizers/html_util.py DELETED
@@ -1,64 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
-
3
- # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
4
- # except for the third-party components listed below.
5
- # Hunyuan 3D does not impose any additional limitations beyond what is outlined
6
- # in the repsective licenses of these third-party components.
7
- # Users must comply with all terms and conditions of original licenses of these third-party
8
- # components and must ensure that the usage of the third party components adheres to
9
- # all relevant laws and regulations.
10
-
11
- # For avoidance of doubts, Hunyuan 3D means the large language models and
12
- # their software and algorithms, including trained model weights, parameters (including
13
- # optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
14
- # fine-tuning enabling code and other elements of the foregoing made publicly available
15
- # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
16
-
17
- import io
18
- import base64
19
- import numpy as np
20
- from PIL import Image
21
-
22
-
23
- def to_html_frame(content):
24
-
25
- html_frame = f"""
26
- <html>
27
- <body>
28
- {content}
29
- </body>
30
- </html>
31
- """
32
-
33
- return html_frame
34
-
35
-
36
- def to_single_row_table(caption: str, content: str):
37
-
38
- table_html = f"""
39
- <table border = "1">
40
- <caption>{caption}</caption>
41
- <tr>
42
- <td>{content}</td>
43
- </tr>
44
- </table>
45
- """
46
-
47
- return table_html
48
-
49
-
50
- def to_image_embed_tag(image: np.ndarray):
51
-
52
- # Convert np.ndarray to bytes
53
- img = Image.fromarray(image)
54
- raw_bytes = io.BytesIO()
55
- img.save(raw_bytes, "PNG")
56
-
57
- # Encode bytes to base64
58
- image_base64 = base64.b64encode(raw_bytes.getvalue()).decode("utf-8")
59
-
60
- image_tag = f"""
61
- <img src="data:image/png;base64,{image_base64}" alt="Embedded Image">
62
- """
63
-
64
- return image_tag
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ultrashape/utils/visualizers/pythreejs_viewer.py DELETED
@@ -1,549 +0,0 @@
1
- # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
2
- # except for the third-party components listed below.
3
- # Hunyuan 3D does not impose any additional limitations beyond what is outlined
4
- # in the repsective licenses of these third-party components.
5
- # Users must comply with all terms and conditions of original licenses of these third-party
6
- # components and must ensure that the usage of the third party components adheres to
7
- # all relevant laws and regulations.
8
-
9
- # For avoidance of doubts, Hunyuan 3D means the large language models and
10
- # their software and algorithms, including trained model weights, parameters (including
11
- # optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
12
- # fine-tuning enabling code and other elements of the foregoing made publicly available
13
- # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
14
-
15
-
16
- import numpy as np
17
- from ipywidgets import embed
18
- import pythreejs as p3s
19
- import uuid
20
-
21
- from .color_util import get_colors, gen_circle, gen_checkers
22
-
23
-
24
- EMBED_URL = "https://cdn.jsdelivr.net/npm/@jupyter-widgets/html-manager@1.0.1/dist/embed-amd.js"
25
-
26
-
27
- class PyThreeJSViewer(object):
28
-
29
- def __init__(self, settings, render_mode="WEBSITE"):
30
- self.render_mode = render_mode
31
- self.__update_settings(settings)
32
- self._light = p3s.DirectionalLight(color='white', position=[0, 0, 1], intensity=0.6)
33
- self._light2 = p3s.AmbientLight(intensity=0.5)
34
- self._cam = p3s.PerspectiveCamera(position=[0, 0, 1], lookAt=[0, 0, 0], fov=self.__s["fov"],
35
- aspect=self.__s["width"] / self.__s["height"], children=[self._light])
36
- self._orbit = p3s.OrbitControls(controlling=self._cam)
37
- self._scene = p3s.Scene(children=[self._cam, self._light2], background=self.__s["background"]) # "#4c4c80"
38
- self._renderer = p3s.Renderer(camera=self._cam, scene=self._scene, controls=[self._orbit],
39
- width=self.__s["width"], height=self.__s["height"],
40
- antialias=self.__s["antialias"])
41
-
42
- self.__objects = {}
43
- self.__cnt = 0
44
-
45
- def jupyter_mode(self):
46
- self.render_mode = "JUPYTER"
47
-
48
- def offline(self):
49
- self.render_mode = "OFFLINE"
50
-
51
- def website(self):
52
- self.render_mode = "WEBSITE"
53
-
54
- def __get_shading(self, shading):
55
- shad = {"flat": True, "wireframe": False, "wire_width": 0.03, "wire_color": "black",
56
- "side": 'DoubleSide', "colormap": "viridis", "normalize": [None, None],
57
- "bbox": False, "roughness": 0.5, "metalness": 0.25, "reflectivity": 1.0,
58
- "line_width": 1.0, "line_color": "black",
59
- "point_color": "red", "point_size": 0.01, "point_shape": "circle",
60
- "text_color": "red"
61
- }
62
- for k in shading:
63
- shad[k] = shading[k]
64
- return shad
65
-
66
- def __update_settings(self, settings={}):
67
- sett = {"width": 1600, "height": 800, "antialias": True, "scale": 1.5, "background": "#ffffff",
68
- "fov": 30}
69
- for k in settings:
70
- sett[k] = settings[k]
71
- self.__s = sett
72
-
73
- def __add_object(self, obj, parent=None):
74
- if not parent: # Object is added to global scene and objects dict
75
- self.__objects[self.__cnt] = obj
76
- self.__cnt += 1
77
- self._scene.add(obj["mesh"])
78
- else: # Object is added to parent object and NOT to objects dict
79
- parent.add(obj["mesh"])
80
-
81
- self.__update_view()
82
-
83
- if self.render_mode == "JUPYTER":
84
- return self.__cnt - 1
85
- elif self.render_mode == "WEBSITE":
86
- return self
87
-
88
- def __add_line_geometry(self, lines, shading, obj=None):
89
- lines = lines.astype("float32", copy=False)
90
- mi = np.min(lines, axis=0)
91
- ma = np.max(lines, axis=0)
92
-
93
- geometry = p3s.LineSegmentsGeometry(positions=lines.reshape((-1, 2, 3)))
94
- material = p3s.LineMaterial(linewidth=shading["line_width"], color=shading["line_color"])
95
- # , vertexColors='VertexColors'),
96
- lines = p3s.LineSegments2(geometry=geometry, material=material) # type='LinePieces')
97
- line_obj = {"geometry": geometry, "mesh": lines, "material": material,
98
- "max": ma, "min": mi, "type": "Lines", "wireframe": None}
99
-
100
- if obj:
101
- return self.__add_object(line_obj, obj), line_obj
102
- else:
103
- return self.__add_object(line_obj)
104
-
105
- def __update_view(self):
106
- if len(self.__objects) == 0:
107
- return
108
- ma = np.zeros((len(self.__objects), 3))
109
- mi = np.zeros((len(self.__objects), 3))
110
- for r, obj in enumerate(self.__objects):
111
- ma[r] = self.__objects[obj]["max"]
112
- mi[r] = self.__objects[obj]["min"]
113
- ma = np.max(ma, axis=0)
114
- mi = np.min(mi, axis=0)
115
- diag = np.linalg.norm(ma - mi)
116
- mean = ((ma - mi) / 2 + mi).tolist()
117
- scale = self.__s["scale"] * (diag)
118
- self._orbit.target = mean
119
- self._cam.lookAt(mean)
120
- self._cam.position = [mean[0], mean[1], mean[2] + scale]
121
- self._light.position = [mean[0], mean[1], mean[2] + scale]
122
-
123
- self._orbit.exec_three_obj_method('update')
124
- self._cam.exec_three_obj_method('updateProjectionMatrix')
125
-
126
- def __get_bbox(self, v):
127
- m = np.min(v, axis=0)
128
- M = np.max(v, axis=0)
129
-
130
- # Corners of the bounding box
131
- v_box = np.array([[m[0], m[1], m[2]], [M[0], m[1], m[2]], [M[0], M[1], m[2]], [m[0], M[1], m[2]],
132
- [m[0], m[1], M[2]], [M[0], m[1], M[2]], [M[0], M[1], M[2]], [m[0], M[1], M[2]]])
133
-
134
- f_box = np.array([[0, 1], [1, 2], [2, 3], [3, 0], [4, 5], [5, 6], [6, 7], [7, 4],
135
- [0, 4], [1, 5], [2, 6], [7, 3]], dtype=np.uint32)
136
- return v_box, f_box
137
-
138
- def __get_colors(self, v, f, c, sh):
139
- coloring = "VertexColors"
140
- if type(c) == np.ndarray and c.size == 3: # Single color
141
- colors = np.ones_like(v)
142
- colors[:, 0] = c[0]
143
- colors[:, 1] = c[1]
144
- colors[:, 2] = c[2]
145
- # print("Single colors")
146
- elif type(c) == np.ndarray and len(c.shape) == 2 and c.shape[1] == 3: # Color values for
147
- if c.shape[0] == f.shape[0]: # faces
148
- colors = np.hstack([c, c, c]).reshape((-1, 3))
149
- coloring = "FaceColors"
150
- # print("Face color values")
151
- elif c.shape[0] == v.shape[0]: # vertices
152
- colors = c
153
- # print("Vertex color values")
154
- else: # Wrong size, fallback
155
- print("Invalid color array given! Supported are numpy arrays.", type(c))
156
- colors = np.ones_like(v)
157
- colors[:, 0] = 1.0
158
- colors[:, 1] = 0.874
159
- colors[:, 2] = 0.0
160
- elif type(c) == np.ndarray and c.size == f.shape[0]: # Function values for faces
161
- normalize = sh["normalize"][0] != None and sh["normalize"][1] != None
162
- cc = get_colors(c, sh["colormap"], normalize=normalize,
163
- vmin=sh["normalize"][0], vmax=sh["normalize"][1])
164
- # print(cc.shape)
165
- colors = np.hstack([cc, cc, cc]).reshape((-1, 3))
166
- coloring = "FaceColors"
167
- # print("Face function values")
168
- elif type(c) == np.ndarray and c.size == v.shape[0]: # Function values for vertices
169
- normalize = sh["normalize"][0] != None and sh["normalize"][1] != None
170
- colors = get_colors(c, sh["colormap"], normalize=normalize,
171
- vmin=sh["normalize"][0], vmax=sh["normalize"][1])
172
- # print("Vertex function values")
173
-
174
- else:
175
- colors = np.ones_like(v)
176
- colors[:, 0] = 1.0
177
- colors[:, 1] = 0.874
178
- colors[:, 2] = 0.0
179
-
180
- # No color
181
- if c is not None:
182
- print("Invalid color array given! Supported are numpy arrays.", type(c))
183
-
184
- return colors, coloring
185
-
186
- def __get_point_colors(self, v, c, sh):
187
- v_color = True
188
- if c is None: # No color given, use global color
189
- # conv = mpl.colors.ColorConverter()
190
- colors = sh["point_color"] # np.array(conv.to_rgb(sh["point_color"]))
191
- v_color = False
192
- elif isinstance(c, str): # No color given, use global color
193
- # conv = mpl.colors.ColorConverter()
194
- colors = c # np.array(conv.to_rgb(c))
195
- v_color = False
196
- elif type(c) == np.ndarray and len(c.shape) == 2 and c.shape[0] == v.shape[0] and c.shape[1] == 3:
197
- # Point color
198
- colors = c.astype("float32", copy=False)
199
-
200
- elif isinstance(c, np.ndarray) and len(c.shape) == 2 and c.shape[0] == v.shape[0] and c.shape[1] != 3:
201
- # Function values for vertices, but the colors are features
202
- c_norm = np.linalg.norm(c, ord=2, axis=-1)
203
- normalize = sh["normalize"][0] != None and sh["normalize"][1] != None
204
- colors = get_colors(c_norm, sh["colormap"], normalize=normalize,
205
- vmin=sh["normalize"][0], vmax=sh["normalize"][1])
206
- colors = colors.astype("float32", copy=False)
207
-
208
- elif type(c) == np.ndarray and c.size == v.shape[0]: # Function color
209
- normalize = sh["normalize"][0] != None and sh["normalize"][1] != None
210
- colors = get_colors(c, sh["colormap"], normalize=normalize,
211
- vmin=sh["normalize"][0], vmax=sh["normalize"][1])
212
- colors = colors.astype("float32", copy=False)
213
- # print("Vertex function values")
214
-
215
- else:
216
- print("Invalid color array given! Supported are numpy arrays.", type(c))
217
- colors = sh["point_color"]
218
- v_color = False
219
-
220
- return colors, v_color
221
-
222
- def add_mesh(self, v, f, c=None, uv=None, n=None, shading={}, texture_data=None, **kwargs):
223
- shading.update(kwargs)
224
- sh = self.__get_shading(shading)
225
- mesh_obj = {}
226
-
227
- # it is a tet
228
- if v.shape[1] == 3 and f.shape[1] == 4:
229
- f_tmp = np.ndarray([f.shape[0] * 4, 3], dtype=f.dtype)
230
- for i in range(f.shape[0]):
231
- f_tmp[i * 4 + 0] = np.array([f[i][1], f[i][0], f[i][2]])
232
- f_tmp[i * 4 + 1] = np.array([f[i][0], f[i][1], f[i][3]])
233
- f_tmp[i * 4 + 2] = np.array([f[i][1], f[i][2], f[i][3]])
234
- f_tmp[i * 4 + 3] = np.array([f[i][2], f[i][0], f[i][3]])
235
- f = f_tmp
236
-
237
- if v.shape[1] == 2:
238
- v = np.append(v, np.zeros([v.shape[0], 1]), 1)
239
-
240
- # Type adjustment vertices
241
- v = v.astype("float32", copy=False)
242
-
243
- # Color setup
244
- colors, coloring = self.__get_colors(v, f, c, sh)
245
-
246
- # Type adjustment faces and colors
247
- c = colors.astype("float32", copy=False)
248
-
249
- # Material and geometry setup
250
- ba_dict = {"color": p3s.BufferAttribute(c)}
251
- if coloring == "FaceColors":
252
- verts = np.zeros((f.shape[0] * 3, 3), dtype="float32")
253
- for ii in range(f.shape[0]):
254
- # print(ii*3, f[ii])
255
- verts[ii * 3] = v[f[ii, 0]]
256
- verts[ii * 3 + 1] = v[f[ii, 1]]
257
- verts[ii * 3 + 2] = v[f[ii, 2]]
258
- v = verts
259
- else:
260
- f = f.astype("uint32", copy=False).ravel()
261
- ba_dict["index"] = p3s.BufferAttribute(f, normalized=False)
262
-
263
- ba_dict["position"] = p3s.BufferAttribute(v, normalized=False)
264
-
265
- if uv is not None:
266
- uv = (uv - np.min(uv)) / (np.max(uv) - np.min(uv))
267
- if texture_data is None:
268
- texture_data = gen_checkers(20, 20)
269
- tex = p3s.DataTexture(data=texture_data, format="RGBFormat", type="FloatType")
270
- material = p3s.MeshStandardMaterial(map=tex, reflectivity=sh["reflectivity"], side=sh["side"],
271
- roughness=sh["roughness"], metalness=sh["metalness"],
272
- flatShading=sh["flat"],
273
- polygonOffset=True, polygonOffsetFactor=1, polygonOffsetUnits=5)
274
- ba_dict["uv"] = p3s.BufferAttribute(uv.astype("float32", copy=False))
275
- else:
276
- material = p3s.MeshStandardMaterial(vertexColors=coloring, reflectivity=sh["reflectivity"],
277
- side=sh["side"], roughness=sh["roughness"], metalness=sh["metalness"],
278
- flatShading=sh["flat"],
279
- polygonOffset=True, polygonOffsetFactor=1, polygonOffsetUnits=5)
280
-
281
- if type(n) != type(None) and coloring == "VertexColors": # TODO: properly handle normals for FaceColors as well
282
- ba_dict["normal"] = p3s.BufferAttribute(n.astype("float32", copy=False), normalized=True)
283
-
284
- geometry = p3s.BufferGeometry(attributes=ba_dict)
285
-
286
- if coloring == "VertexColors" and type(n) == type(None):
287
- geometry.exec_three_obj_method('computeVertexNormals')
288
- elif coloring == "FaceColors" and type(n) == type(None):
289
- geometry.exec_three_obj_method('computeFaceNormals')
290
-
291
- # Mesh setup
292
- mesh = p3s.Mesh(geometry=geometry, material=material)
293
-
294
- # Wireframe setup
295
- mesh_obj["wireframe"] = None
296
- if sh["wireframe"]:
297
- wf_geometry = p3s.WireframeGeometry(mesh.geometry) # WireframeGeometry
298
- wf_material = p3s.LineBasicMaterial(color=sh["wire_color"], linewidth=sh["wire_width"])
299
- wireframe = p3s.LineSegments(wf_geometry, wf_material)
300
- mesh.add(wireframe)
301
- mesh_obj["wireframe"] = wireframe
302
-
303
- # Bounding box setup
304
- if sh["bbox"]:
305
- v_box, f_box = self.__get_bbox(v)
306
- _, bbox = self.add_edges(v_box, f_box, sh, mesh)
307
- mesh_obj["bbox"] = [bbox, v_box, f_box]
308
-
309
- # Object setup
310
- mesh_obj["max"] = np.max(v, axis=0)
311
- mesh_obj["min"] = np.min(v, axis=0)
312
- mesh_obj["geometry"] = geometry
313
- mesh_obj["mesh"] = mesh
314
- mesh_obj["material"] = material
315
- mesh_obj["type"] = "Mesh"
316
- mesh_obj["shading"] = sh
317
- mesh_obj["coloring"] = coloring
318
- mesh_obj["arrays"] = [v, f, c] # TODO replays with proper storage or remove if not needed
319
-
320
- return self.__add_object(mesh_obj)
321
-
322
- def add_lines(self, beginning, ending, shading={}, obj=None, **kwargs):
323
- shading.update(kwargs)
324
- if len(beginning.shape) == 1:
325
- if len(beginning) == 2:
326
- beginning = np.array([[beginning[0], beginning[1], 0]])
327
- else:
328
- if beginning.shape[1] == 2:
329
- beginning = np.append(
330
- beginning, np.zeros([beginning.shape[0], 1]), 1)
331
- if len(ending.shape) == 1:
332
- if len(ending) == 2:
333
- ending = np.array([[ending[0], ending[1], 0]])
334
- else:
335
- if ending.shape[1] == 2:
336
- ending = np.append(
337
- ending, np.zeros([ending.shape[0], 1]), 1)
338
-
339
- sh = self.__get_shading(shading)
340
- lines = np.hstack([beginning, ending])
341
- lines = lines.reshape((-1, 3))
342
- return self.__add_line_geometry(lines, sh, obj)
343
-
344
- def add_edges(self, vertices, edges, shading={}, obj=None, **kwargs):
345
- shading.update(kwargs)
346
- if vertices.shape[1] == 2:
347
- vertices = np.append(
348
- vertices, np.zeros([vertices.shape[0], 1]), 1)
349
- sh = self.__get_shading(shading)
350
- lines = np.zeros((edges.size, 3))
351
- cnt = 0
352
- for e in edges:
353
- lines[cnt, :] = vertices[e[0]]
354
- lines[cnt + 1, :] = vertices[e[1]]
355
- cnt += 2
356
- return self.__add_line_geometry(lines, sh, obj)
357
-
358
- def add_points(self, points, c=None, shading={}, obj=None, **kwargs):
359
- shading.update(kwargs)
360
- if len(points.shape) == 1:
361
- if len(points) == 2:
362
- points = np.array([[points[0], points[1], 0]])
363
- else:
364
- if points.shape[1] == 2:
365
- points = np.append(
366
- points, np.zeros([points.shape[0], 1]), 1)
367
- sh = self.__get_shading(shading)
368
- points = points.astype("float32", copy=False)
369
- mi = np.min(points, axis=0)
370
- ma = np.max(points, axis=0)
371
-
372
- g_attributes = {"position": p3s.BufferAttribute(points, normalized=False)}
373
- m_attributes = {"size": sh["point_size"]}
374
-
375
- if sh["point_shape"] == "circle": # Plot circles
376
- tex = p3s.DataTexture(data=gen_circle(16, 16), format="RGBAFormat", type="FloatType")
377
- m_attributes["map"] = tex
378
- m_attributes["alphaTest"] = 0.5
379
- m_attributes["transparency"] = True
380
- else: # Plot squares
381
- pass
382
-
383
- colors, v_colors = self.__get_point_colors(points, c, sh)
384
- if v_colors: # Colors per point
385
- m_attributes["vertexColors"] = 'VertexColors'
386
- g_attributes["color"] = p3s.BufferAttribute(colors, normalized=False)
387
-
388
- else: # Colors for all points
389
- m_attributes["color"] = colors
390
-
391
- material = p3s.PointsMaterial(**m_attributes)
392
- geometry = p3s.BufferGeometry(attributes=g_attributes)
393
- points = p3s.Points(geometry=geometry, material=material)
394
- point_obj = {"geometry": geometry, "mesh": points, "material": material,
395
- "max": ma, "min": mi, "type": "Points", "wireframe": None}
396
-
397
- if obj:
398
- return self.__add_object(point_obj, obj), point_obj
399
- else:
400
- return self.__add_object(point_obj)
401
-
402
- def remove_object(self, obj_id):
403
- if obj_id not in self.__objects:
404
- print("Invalid object id. Valid ids are: ", list(self.__objects.keys()))
405
- return
406
- self._scene.remove(self.__objects[obj_id]["mesh"])
407
- del self.__objects[obj_id]
408
- self.__update_view()
409
-
410
- def reset(self):
411
- for obj_id in list(self.__objects.keys()).copy():
412
- self._scene.remove(self.__objects[obj_id]["mesh"])
413
- del self.__objects[obj_id]
414
- self.__update_view()
415
-
416
- def update_object(self, oid=0, vertices=None, colors=None, faces=None):
417
- obj = self.__objects[oid]
418
- if type(vertices) != type(None):
419
- if obj["coloring"] == "FaceColors":
420
- f = obj["arrays"][1]
421
- verts = np.zeros((f.shape[0] * 3, 3), dtype="float32")
422
- for ii in range(f.shape[0]):
423
- # print(ii*3, f[ii])
424
- verts[ii * 3] = vertices[f[ii, 0]]
425
- verts[ii * 3 + 1] = vertices[f[ii, 1]]
426
- verts[ii * 3 + 2] = vertices[f[ii, 2]]
427
- v = verts
428
-
429
- else:
430
- v = vertices.astype("float32", copy=False)
431
- obj["geometry"].attributes["position"].array = v
432
- # self.wireframe.attributes["position"].array = v # Wireframe updates?
433
- obj["geometry"].attributes["position"].needsUpdate = True
434
- # obj["geometry"].exec_three_obj_method('computeVertexNormals')
435
- if type(colors) != type(None):
436
- colors, coloring = self.__get_colors(obj["arrays"][0], obj["arrays"][1], colors, obj["shading"])
437
- colors = colors.astype("float32", copy=False)
438
- obj["geometry"].attributes["color"].array = colors
439
- obj["geometry"].attributes["color"].needsUpdate = True
440
- if type(faces) != type(None):
441
- if obj["coloring"] == "FaceColors":
442
- print("Face updates are currently only possible in vertex color mode.")
443
- return
444
- f = faces.astype("uint32", copy=False).ravel()
445
- print(obj["geometry"].attributes)
446
- obj["geometry"].attributes["index"].array = f
447
- # self.wireframe.attributes["position"].array = v # Wireframe updates?
448
- obj["geometry"].attributes["index"].needsUpdate = True
449
- # obj["geometry"].exec_three_obj_method('computeVertexNormals')
450
- # self.mesh.geometry.verticesNeedUpdate = True
451
- # self.mesh.geometry.elementsNeedUpdate = True
452
- # self.update()
453
- if self.render_mode == "WEBSITE":
454
- return self
455
-
456
- # def update(self):
457
- # self.mesh.exec_three_obj_method('update')
458
- # self.orbit.exec_three_obj_method('update')
459
- # self.cam.exec_three_obj_method('updateProjectionMatrix')
460
- # self.scene.exec_three_obj_method('update')
461
-
462
- def add_text(self, text, shading={}, **kwargs):
463
- shading.update(kwargs)
464
- sh = self.__get_shading(shading)
465
- tt = p3s.TextTexture(string=text, color=sh["text_color"])
466
- sm = p3s.SpriteMaterial(map=tt)
467
- text = p3s.Sprite(material=sm, scaleToTexture=True)
468
- self._scene.add(text)
469
-
470
- # def add_widget(self, widget, callback):
471
- # self.widgets.append(widget)
472
- # widget.observe(callback, names='value')
473
-
474
- # def add_dropdown(self, options, default, desc, cb):
475
- # widget = widgets.Dropdown(options=options, value=default, description=desc)
476
- # self.__widgets.append(widget)
477
- # widget.observe(cb, names="value")
478
- # display(widget)
479
-
480
- # def add_button(self, text, cb):
481
- # button = widgets.Button(description=text)
482
- # self.__widgets.append(button)
483
- # button.on_click(cb)
484
- # display(button)
485
-
486
- def to_html(self, imports=True, html_frame=True):
487
- # Bake positions (fixes centering bug in offline rendering)
488
- if len(self.__objects) == 0:
489
- return
490
- ma = np.zeros((len(self.__objects), 3))
491
- mi = np.zeros((len(self.__objects), 3))
492
- for r, obj in enumerate(self.__objects):
493
- ma[r] = self.__objects[obj]["max"]
494
- mi[r] = self.__objects[obj]["min"]
495
- ma = np.max(ma, axis=0)
496
- mi = np.min(mi, axis=0)
497
- diag = np.linalg.norm(ma - mi)
498
- mean = (ma - mi) / 2 + mi
499
- for r, obj in enumerate(self.__objects):
500
- v = self.__objects[obj]["geometry"].attributes["position"].array
501
- v -= mean
502
- # v += np.array([0.0, .9, 0.0]) #! to move the obj to the center of window
503
-
504
- scale = self.__s["scale"] * (diag)
505
- self._orbit.target = [0.0, 0.0, 0.0]
506
- self._cam.lookAt([0.0, 0.0, 0.0])
507
- # self._cam.position = [0.0, 0.0, scale]
508
- self._cam.position = [0.0, 0.5, scale * 1.3] #! show four complete meshes in the window
509
- self._light.position = [0.0, 0.0, scale]
510
-
511
- state = embed.dependency_state(self._renderer)
512
-
513
- # Somehow these entries are missing when the state is exported in python.
514
- # Exporting from the GUI works, so we are inserting the missing entries.
515
- for k in state:
516
- if state[k]["model_name"] == "OrbitControlsModel":
517
- state[k]["state"]["maxAzimuthAngle"] = "inf"
518
- state[k]["state"]["maxDistance"] = "inf"
519
- state[k]["state"]["maxZoom"] = "inf"
520
- state[k]["state"]["minAzimuthAngle"] = "-inf"
521
-
522
- tpl = embed.load_requirejs_template
523
- if not imports:
524
- embed.load_requirejs_template = ""
525
-
526
- s = embed.embed_snippet(self._renderer, state=state, embed_url=EMBED_URL)
527
- # s = embed.embed_snippet(self.__w, state=state)
528
- embed.load_requirejs_template = tpl
529
-
530
- if html_frame:
531
- s = "<html>\n<body>\n" + s + "\n</body>\n</html>"
532
-
533
- # Revert changes
534
- for r, obj in enumerate(self.__objects):
535
- v = self.__objects[obj]["geometry"].attributes["position"].array
536
- v += mean
537
- self.__update_view()
538
-
539
- return s
540
-
541
- def save(self, filename=""):
542
- if filename == "":
543
- uid = str(uuid.uuid4()) + ".html"
544
- else:
545
- filename = filename.replace(".html", "")
546
- uid = filename + '.html'
547
- with open(uid, "w") as f:
548
- f.write(self.to_html())
549
- print("Plot saved to file %s." % uid)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ultrashape/utils/voxelize.py DELETED
@@ -1,74 +0,0 @@
1
- import torch
2
-
3
- def voxelize_from_point(pc, num_latents, resolution=128):
4
-
5
- B, N, D = pc.shape
6
- device = pc.device
7
-
8
- norm_pc = (pc + 1.0) / 2.0
9
- voxel_indices = torch.floor(norm_pc * resolution).long()
10
- voxel_indices = torch.clamp(voxel_indices, 0, resolution - 1) # (B, N, 3)
11
-
12
- batch_idx = torch.arange(B, device=device).view(B, 1).expand(B, N)
13
- flat_indices = torch.cat([batch_idx.unsqueeze(-1), voxel_indices], dim=-1).view(-1, 4)
14
- unique_voxels = torch.unique(flat_indices, dim=0)
15
- u_batch_ids = unique_voxels[:, 0]
16
-
17
- noise = torch.rand_like(u_batch_ids, dtype=torch.float)
18
- sort_keys = u_batch_ids.float() + noise
19
- perm = torch.argsort(sort_keys)
20
- shuffled_voxels = unique_voxels[perm]
21
- shuffled_batch_ids = shuffled_voxels[:, 0].contiguous()
22
-
23
- counts = torch.bincount(shuffled_batch_ids, minlength=B)
24
- min_count = counts.min().item()
25
-
26
- # Always aim for num_latents
27
- actual_k = num_latents
28
-
29
- if min_count < num_latents:
30
- print(f"[Info] Voxel count ({min_count}) < Target ({num_latents}). Sampling with replacement.")
31
- # If we don't have enough unique voxels, we need to sample with replacement/repetition
32
- # We can just repeat the indices to fill the gap
33
-
34
- batch_starts = torch.searchsorted(shuffled_batch_ids, torch.arange(B, device=device))
35
-
36
- # Create gathering indices that wrap around for each batch
37
- # For each batch element i, we want actual_k indices
38
- # They start at batch_starts[i] and go up to batch_starts[i] + counts[i]
39
- # We use modulo to wrap around: (j % counts[i]) + batch_starts[i]
40
-
41
- # Expand for broadcasting
42
- batch_starts_exp = batch_starts.unsqueeze(1) # [B, 1]
43
- counts_exp = counts.unsqueeze(1) # [B, 1]
44
-
45
- offsets = torch.arange(actual_k, device=device).unsqueeze(0) # [1, K]
46
-
47
- # Calculate offsets modulo the available count for each batch
48
- # This effectively repeats the available voxels to fill the desired size
49
- # We need to be careful about division by zero if a batch has 0 voxels (shouldn't happen with valid PC)
50
- counts_exp = torch.maximum(counts_exp, torch.tensor(1, device=device))
51
-
52
- wrapped_offsets = offsets % counts_exp
53
- gather_indices = batch_starts_exp + wrapped_offsets
54
- gather_indices = gather_indices.view(-1)
55
-
56
- else:
57
- # Standard case: enough points, just take the first k
58
- batch_starts = torch.searchsorted(shuffled_batch_ids, torch.arange(B, device=device))
59
- offsets = torch.arange(actual_k, device=device).unsqueeze(0)
60
- gather_indices = batch_starts.unsqueeze(1) + offsets
61
- gather_indices = gather_indices.view(-1)
62
-
63
- selected_indices = shuffled_voxels[gather_indices]
64
-
65
- final_grid_coords = selected_indices[:, 1:]
66
-
67
- # Grid Index -> Voxel Center
68
- voxel_size = 2.0 / resolution
69
- final_centers = (final_grid_coords.float() + 0.5) * voxel_size - 1.0
70
-
71
- sampled_pc = final_centers.view(B, actual_k, 3)
72
- sampled_indices = final_grid_coords.view(B, actual_k, 3)
73
-
74
- return sampled_pc, sampled_indices