Spaces:
Running on Zero
Running on Zero
Delete ultrashape
Browse files- ultrashape/__init__.py +0 -17
- ultrashape/data/objaverse_dit.py +0 -331
- ultrashape/data/objaverse_vae.py +0 -262
- ultrashape/data/utils.py +0 -193
- ultrashape/models/__init__.py +0 -27
- ultrashape/models/autoencoders/__init__.py +0 -21
- ultrashape/models/autoencoders/attention_blocks.py +0 -711
- ultrashape/models/autoencoders/attention_processors.py +0 -103
- ultrashape/models/autoencoders/model.py +0 -377
- ultrashape/models/autoencoders/surface_extractors.py +0 -266
- ultrashape/models/autoencoders/vae_trainer.py +0 -229
- ultrashape/models/autoencoders/volume_decoders.py +0 -440
- ultrashape/models/conditioner_mask.py +0 -337
- ultrashape/models/denoisers/__init__.py +0 -22
- ultrashape/models/denoisers/dit_mask.py +0 -725
- ultrashape/models/denoisers/moe_layers.py +0 -177
- ultrashape/models/diffusion/flow_matching_dit_trainer.py +0 -313
- ultrashape/models/diffusion/transport/__init__.py +0 -97
- ultrashape/models/diffusion/transport/integrators.py +0 -142
- ultrashape/models/diffusion/transport/path.py +0 -220
- ultrashape/models/diffusion/transport/transport.py +0 -534
- ultrashape/models/diffusion/transport/utils.py +0 -54
- ultrashape/pipelines.py +0 -797
- ultrashape/postprocessors.py +0 -209
- ultrashape/preprocessors.py +0 -167
- ultrashape/rembg.py +0 -32
- ultrashape/schedulers.py +0 -480
- ultrashape/surface_loaders.py +0 -233
- ultrashape/utils/__init__.py +0 -6
- ultrashape/utils/ema.py +0 -76
- ultrashape/utils/misc.py +0 -200
- ultrashape/utils/trainings/__init__.py +0 -1
- ultrashape/utils/trainings/callback.py +0 -213
- ultrashape/utils/trainings/lr_scheduler.py +0 -53
- ultrashape/utils/trainings/mesh.py +0 -128
- ultrashape/utils/trainings/mesh_log_callback.py +0 -342
- ultrashape/utils/trainings/peft.py +0 -78
- ultrashape/utils/typing.py +0 -41
- ultrashape/utils/utils.py +0 -128
- ultrashape/utils/visualizers/__init__.py +0 -1
- ultrashape/utils/visualizers/color_util.py +0 -57
- ultrashape/utils/visualizers/html_util.py +0 -64
- ultrashape/utils/visualizers/pythreejs_viewer.py +0 -549
- ultrashape/utils/voxelize.py +0 -74
ultrashape/__init__.py
DELETED
|
@@ -1,17 +0,0 @@
|
|
| 1 |
-
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
| 2 |
-
# except for the third-party components listed below.
|
| 3 |
-
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
| 4 |
-
# in the respective licenses of these third-party components.
|
| 5 |
-
# Users must comply with all terms and conditions of original licenses of these third-party
|
| 6 |
-
# components and must ensure that the usage of the third party components adheres to
|
| 7 |
-
# all relevant laws and regulations.
|
| 8 |
-
|
| 9 |
-
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
| 10 |
-
# their software and algorithms, including trained model weights, parameters (including
|
| 11 |
-
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
| 12 |
-
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
| 13 |
-
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
| 14 |
-
|
| 15 |
-
from .pipelines import UltraShapePipeline
|
| 16 |
-
from .postprocessors import FaceReducer, FloaterRemover, DegenerateFaceRemover, MeshSimplifier
|
| 17 |
-
from .preprocessors import ImageProcessorV2, IMAGE_PROCESSORS, DEFAULT_IMAGEPROCESSOR
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ultrashape/data/objaverse_dit.py
DELETED
|
@@ -1,331 +0,0 @@
|
|
| 1 |
-
# -*- coding: utf-8 -*-
|
| 2 |
-
|
| 3 |
-
# ==============================================================================
|
| 4 |
-
# Original work Copyright (c) 2025 Tencent.
|
| 5 |
-
# Modified work Copyright (c) 2025 UltraShape Team.
|
| 6 |
-
#
|
| 7 |
-
# Modified by UltraShape on 2025.12.25
|
| 8 |
-
# ==============================================================================
|
| 9 |
-
|
| 10 |
-
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
| 11 |
-
# except for the third-party components listed below.
|
| 12 |
-
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
| 13 |
-
# in the respective licenses of these third-party components.
|
| 14 |
-
# Users must comply with all terms and conditions of original licenses of these third-party
|
| 15 |
-
# components and must ensure that the usage of the third party components adheres to
|
| 16 |
-
# all relevant laws and regulations.
|
| 17 |
-
|
| 18 |
-
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
| 19 |
-
# their software and algorithms, including trained model weights, parameters (including
|
| 20 |
-
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
| 21 |
-
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
| 22 |
-
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
| 23 |
-
|
| 24 |
-
import math
|
| 25 |
-
import os
|
| 26 |
-
import json
|
| 27 |
-
from dataclasses import dataclass, field
|
| 28 |
-
|
| 29 |
-
import random
|
| 30 |
-
import imageio
|
| 31 |
-
import numpy as np
|
| 32 |
-
import pytorch_lightning as pl
|
| 33 |
-
import torch
|
| 34 |
-
import torch.nn.functional as F
|
| 35 |
-
from torch.utils.data import DataLoader, Dataset
|
| 36 |
-
from PIL import Image
|
| 37 |
-
import pickle
|
| 38 |
-
from ultrashape.utils.typing import *
|
| 39 |
-
import pandas as pd
|
| 40 |
-
import cv2
|
| 41 |
-
import torchvision.transforms as transforms
|
| 42 |
-
from pytorch_lightning.utilities import rank_zero_info
|
| 43 |
-
|
| 44 |
-
def padding(image, mask, center=True, padding_ratio_range=(1.15, 1.15)):
    """
    Pad the input image and mask to a square shape with padding ratio.

    The square side is ``int(max(H, W) * padding_ratio)``. The image is pasted
    onto a white (255) canvas and the mask onto a black (0) canvas, both at
    the same offset. Horizontally the content is always centered.

    Args:
        image (np.ndarray): Input image array of shape (H, W, C). The canvas
            is allocated with 3 channels, so the image must be assignable
            into an (H, W, 3) region.
        mask (np.ndarray): Corresponding mask array of shape (H, W).
        center (bool): If True, center the content vertically. If False,
            place it near the bottom of the canvas, leaving
            ``resize_side // 20`` rows below it; the offset is clamped at 0
            so the content always fits on the canvas.
        padding_ratio_range (sequence): ``[min, max]`` padding ratio. When
            both entries are equal that ratio is used directly (no random
            draw); otherwise a ratio is drawn with ``random.uniform``.

    Returns:
        newimg (np.ndarray): Padded image of shape (resize_side, resize_side, 3).
        newmask (np.ndarray): Padded mask of shape (resize_side, resize_side).

    Raises:
        ValueError: If the selected ratio produces a canvas smaller than the
            longer side of the input.
    """
    h, w = image.shape[:2]
    max_side = max(h, w)

    # Select padding ratio either fixed or randomly within the given range
    ratio_lo, ratio_hi = padding_ratio_range[0], padding_ratio_range[1]
    if ratio_lo == ratio_hi:
        padding_ratio = ratio_lo
    else:
        padding_ratio = random.uniform(ratio_lo, ratio_hi)
    resize_side = int(max_side * padding_ratio)

    # A canvas smaller than the input cannot hold it; fail with a clear
    # message instead of an opaque NumPy broadcasting error further down.
    if resize_side < max_side:
        raise ValueError(
            f"padding ratio {padding_ratio} gives a canvas of side {resize_side}, "
            f"smaller than the input side {max_side}"
        )

    pad_h = resize_side - h
    pad_w = resize_side - w
    if center:
        start_h = pad_h // 2
    else:
        # Keep a small margin below the content, but never start above row 0
        # (the unclamped offset goes negative when the padding is small).
        start_h = max(pad_h - resize_side // 20, 0)

    start_w = pad_w // 2

    # Create new white image and black mask with padded size
    newimg = np.ones((resize_side, resize_side, 3), dtype=np.uint8) * 255
    newmask = np.zeros((resize_side, resize_side), dtype=np.uint8)

    # Place original image and mask into the padded canvas
    newimg[start_h:start_h + h, start_w:start_w + w] = image
    newmask[start_h:start_h + h, start_w:start_w + w] = mask

    return newimg, newmask
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
class ObjaverseDataset(Dataset):
    """
    Dataset that pairs a sampled surface point cloud with one rendered view.

    Each item is identified by a uid. The shape comes from
    ``{sample_root}/{uid}.npz`` and the image from an RGBA PNG located under
    the directory registered for that uid in the ``image_path`` JSON.
    """

    # Maximum number of replacement samples tried by __getitem__ before it
    # gives up, so a fully broken dataset fails fast instead of recursing.
    _MAX_LOAD_RETRIES = 100

    def __init__(
        self,
        data_json,
        sample_root,
        image_path,
        image_transform = None,
        pc_size: int = 2048,
        pc_sharpedge_size: int = 2048,
        sharpedge_label: bool = False,
        return_normal: bool = False,
        padding = True,
        padding_ratio_range=(1.15, 1.15),
    ):
        """
        Args:
            data_json: Path to a JSON file holding the uids of this split.
            sample_root: Directory containing one ``{uid}.npz`` per uid.
            image_path: Path to a JSON file that is indexed by uid to obtain
                the image directory of that object.
            image_transform: Optional callable applied to the image and to
                the 3-channel mask.
            pc_size: Number of surface points to sample (half from the
                uniform part, half from the sharp part).
            pc_sharpedge_size: Only used in the size assertion of
                ``_load_shape`` and in the log output.
            sharpedge_label: If True, append a 0/1 column marking points
                taken from the sharp part.
            return_normal: If True, append the normalized normals to the
                returned surface tensor.
            padding: If True, crop the image to the mask and pad it to a
                square before the transform.
            padding_ratio_range: ``[min, max]`` ratio forwarded to the
                module-level ``padding`` function.
        """
        super().__init__()

        # Context managers make sure the JSON files are closed right away.
        with open(data_json) as uid_file:
            self.uids = json.load(uid_file)
        self.sample_root = sample_root
        with open(image_path) as image_file:
            self.image_paths = json.load(image_file)
        self.image_transform = image_transform

        self.pc_size = pc_size
        self.pc_sharpedge_size = pc_sharpedge_size
        self.sharpedge_label = sharpedge_label
        self.return_normal = return_normal

        self.padding = padding
        self.padding_ratio_range = padding_ratio_range

        print(f"Loaded {len(self.uids)} uids from {data_json}.")

        rank_zero_info(f'*' * 50)
        rank_zero_info(f'Dataset Infos:')
        rank_zero_info(f'# of 3D file: {len(self.uids)}')
        rank_zero_info(f'# of Surface Points: {self.pc_size}')
        rank_zero_info(f'# of Sharpedge Surface Points: {self.pc_sharpedge_size}')
        rank_zero_info(f'Using sharp edge label: {self.sharpedge_label}')
        rank_zero_info(f'*' * 50)

    def __len__(self):
        """Return the number of uids in this split."""
        return len(self.uids)

    def _load_shape(self, index: int) -> Dict[str, Any]:
        """
        Load and subsample the surface points of item ``index``.

        Returns:
            Dict with ``uid``, ``surface`` (float tensor whose columns are
            xyz, optionally followed by the normalized normal and the
            sharp-edge label) and the placeholder ``geo_points`` (0.0).
        """
        # The NpzFile keeps the archive open until it is closed, so read
        # everything needed inside the context manager.
        with np.load(f'{self.sample_root}/{self.uids[index]}.npz') as data:
            # Points are rescaled with (x - 0.5) * 2; presumably they are
            # stored in [0, 1] -- TODO confirm against the preprocessing.
            surface_og = (np.asarray(data['clean_surface_points']) - 0.5) * 2
            normal = np.asarray(data['clean_surface_normals'])
        surface_og_n = np.concatenate([surface_og, normal], axis=1)

        # hard code: first 300k are uniform, last 300k are sharp
        assert surface_og_n.shape[0] == 600000, f"assume that suface points = 30w uniform + 30w curvature, but {len(surface_og_n)=}"
        coarse_surface = surface_og_n[:300000]
        sharp_surface = surface_og_n[300000:]

        surface_normal = []
        rng = np.random.default_rng()
        # NOTE(review): the extent of this guard was reconstructed from
        # de-indented text; both halves depend on pc_size, so both are
        # placed under it. Behavior is identical for any pc_size > 0.
        if self.pc_size > 0:
            ind = rng.choice(coarse_surface.shape[0], self.pc_size // 2, replace=False)
            coarse_surface = coarse_surface[ind]
            if self.sharpedge_label:
                sharpedge_label = np.zeros((self.pc_size // 2, 1))
                coarse_surface = np.concatenate((coarse_surface, sharpedge_label), axis=1)
            surface_normal.append(coarse_surface)

            ind_sharpedge = rng.choice(sharp_surface.shape[0], self.pc_size // 2, replace=False)
            sharp_surface = sharp_surface[ind_sharpedge]
            if self.sharpedge_label:
                sharpedge_label = np.ones((self.pc_size // 2, 1))
                sharp_surface = np.concatenate((sharp_surface, sharpedge_label), axis=1)
            surface_normal.append(sharp_surface)

        surface_normal = np.concatenate(surface_normal, axis=0)
        surface_normal = torch.FloatTensor(surface_normal)
        surface = surface_normal[:, 0:3]
        normal = surface_normal[:, 3:6]
        # NOTE(review): the sampling above always yields 2 * (pc_size // 2)
        # rows, so this only holds for an even pc_size combined with
        # pc_sharpedge_size == 0 -- confirm the intended configuration.
        assert surface.shape[0] == self.pc_size + self.pc_sharpedge_size

        geo_points = 0.0
        normal = torch.nn.functional.normalize(normal, p=2, dim=1)
        if self.return_normal:
            surface = torch.cat([surface, normal], dim=-1)
        if self.sharpedge_label:
            # Last column of surface_normal is the 0/1 sharp-edge label.
            surface = torch.cat([surface, surface_normal[:, -1:]], dim=-1)

        ret = {
            "uid": self.uids[index],
            "surface": surface,
            "geo_points": geo_points
        }
        return ret

    def _load_image(self, index: int) -> Dict[str, Any]:
        """
        Load one randomly chosen RGBA view of item ``index``.

        The view index is drawn from 0..15. The image is composited over a
        white background, converted from BGR to RGB, optionally cropped to
        the mask and padded to a square, and finally transformed.

        Returns:
            Dict with ``sel_image_idx``, ``image`` and ``mask`` (only the
            first channel of the transformed mask is kept).

        Raises:
            FileNotFoundError: If the image file cannot be read.
        """
        ret = {}
        sel_idx = random.randint(0, 15)
        ret["sel_image_idx"] = sel_idx
        obj_name = self.uids[index]
        img_path = f'{self.image_paths[obj_name]}/{os.path.basename(self.image_paths[obj_name])}/rgba/' + f"{sel_idx:03d}.png"

        images, masks = [], []
        image = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {img_path}")
        assert image.shape[2] == 4
        alpha = image[:, :, 3:4].astype(np.float32) / 255
        forground = image[:, :, :3]
        background = np.ones_like(forground) * 255
        # Alpha-composite the foreground over a white background.
        img_new = forground * alpha + background * (1 - alpha)
        image = img_new.astype(np.uint8)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        mask = (alpha[:, :, 0] * 255).astype(np.uint8)

        if self.padding:
            h, w = image.shape[:2]
            binary = mask > 0.3
            # np.argwhere yields (row, col) pairs, so the "x" values below
            # are row indices and the "y" values are column indices.
            non_zero_coords = np.argwhere(binary)
            x_min, y_min = non_zero_coords.min(axis=0)
            x_max, y_max = non_zero_coords.max(axis=0)
            # Crop to the mask bounding box with a 5 pixel margin, then pad.
            image, mask = padding(
                image[max(x_min - 5, 0):min(x_max + 5, h), max(y_min - 5, 0):min(y_max + 5, w)],
                mask[max(x_min - 5, 0):min(x_max + 5, h), max(y_min - 5, 0):min(y_max + 5, w)],
                center=True, padding_ratio_range=self.padding_ratio_range)

        if self.image_transform:
            image = self.image_transform(image)
            # Replicate the mask to 3 channels so the same transform applies.
            mask = np.stack((mask, mask, mask), axis=-1)
            mask = self.image_transform(mask)

        images.append(image)
        masks.append(mask)
        ret["image"] = torch.cat(images, dim=0)
        ret["mask"] = torch.cat(masks, dim=0)[:1, ...]

        return ret

    def get_data(self, index):
        """Return the merged shape and image dictionaries of item ``index``."""
        ret = self._load_shape(index)
        ret.update(self._load_image(index))
        return ret

    def __getitem__(self, index):
        """
        Return item ``index``, falling back to random items on failure.

        A failing item is reported on stdout and replaced by a randomly
        chosen one. After ``_MAX_LOAD_RETRIES`` replacements a RuntimeError
        is raised (the previous unbounded recursion ended in RecursionError,
        which is itself a RuntimeError subclass).
        """
        last_error = None
        for _ in range(self._MAX_LOAD_RETRIES + 1):
            try:
                return self.get_data(index)
            except Exception as e:
                print(f"Error in {self.uids[index]}: {e}")
                last_error = e
                index = np.random.randint(len(self))
        raise RuntimeError(
            f"Failed to load a valid sample after {self._MAX_LOAD_RETRIES} retries"
        ) from last_error

    def collate(self, batch):
        """Collate a list of items with ``torch.utils.data.default_collate``."""
        batch = torch.utils.data.default_collate(batch)
        return batch
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
class ObjaverseDataModule(pl.LightningDataModule):
    """
    Lightning data module that builds train/val loaders over ObjaverseDataset.

    The train split is read from ``{training_data_list}/train.json`` and the
    validation split from ``{training_data_list}/val.json``.
    """

    def __init__(
        self,
        batch_size: int = 1,
        num_workers: int = 4,
        val_num_workers: int = 2,
        training_data_list: str = None,
        sample_pcd_dir: str = None,
        image_data_json: str = None,
        image_size: int = 224,
        mean: Union[List[float], Tuple[float]] = (0.485, 0.456, 0.406),
        std: Union[List[float], Tuple[float]] = (0.229, 0.224, 0.225),
        pc_size: int = 2048,
        pc_sharpedge_size: int = 2048,
        sharpedge_label: bool = False,
        return_normal: bool = False,
        padding = True,
        padding_ratio_range=(1.15, 1.15)
    ):
        """
        Args:
            batch_size: Batch size used by both loaders.
            num_workers: Worker count of the training loader.
            val_num_workers: Worker count of the validation loader.
            training_data_list: Directory containing ``train.json`` and
                ``val.json``.
            sample_pcd_dir: Forwarded to the dataset as ``sample_root``.
            image_data_json: Forwarded to the dataset as ``image_path``.
            image_size: Size passed to ``transforms.Resize``.
            mean: Mean passed to ``transforms.Normalize``.
            std: Standard deviation passed to ``transforms.Normalize``.
            pc_size, pc_sharpedge_size, sharpedge_label, return_normal,
            padding, padding_ratio_range: Forwarded unchanged to the dataset.
        """
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_num_workers = val_num_workers

        self.training_data_list = training_data_list
        self.sample_pcd_dir = sample_pcd_dir
        self.image_data_json = image_data_json

        self.image_size = image_size
        self.mean = mean
        self.std = std
        # Train and val currently share the same pipeline; two separate
        # attributes are kept so either can be customized independently.
        self.train_image_transform = self._build_image_transform()
        self.val_image_transform = self._build_image_transform()

        self.pc_size = pc_size
        self.pc_sharpedge_size = pc_sharpedge_size
        self.sharpedge_label = sharpedge_label
        self.return_normal = return_normal

        self.padding = padding
        self.padding_ratio_range = padding_ratio_range

    def _build_image_transform(self):
        """Return the ToTensor -> Resize -> Normalize image pipeline."""
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(self.image_size),
            transforms.Normalize(mean=self.mean, std=self.std)])

    def _build_dataloader(self, data_json, image_transform, num_workers):
        """Create the dataset for one split and wrap it in a DataLoader."""
        asl_params = {
            "data_json": data_json,
            "sample_root": self.sample_pcd_dir,
            "image_path": self.image_data_json,
            "image_transform": image_transform,
            "pc_size": self.pc_size,
            "pc_sharpedge_size": self.pc_sharpedge_size,
            "sharpedge_label": self.sharpedge_label,
            "return_normal": self.return_normal,
            "padding": self.padding,
            "padding_ratio_range": self.padding_ratio_range,
        }
        dataset = ObjaverseDataset(**asl_params)
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=True,
        )

    def train_dataloader(self):
        """Return the DataLoader of the training split."""
        return self._build_dataloader(
            f'{self.training_data_list}/train.json',
            self.train_image_transform,
            self.num_workers,
        )

    def val_dataloader(self):
        """Return the DataLoader of the validation split."""
        return self._build_dataloader(
            f'{self.training_data_list}/val.json',
            self.val_image_transform,
            self.val_num_workers,
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ultrashape/data/objaverse_vae.py
DELETED
|
@@ -1,262 +0,0 @@
|
|
| 1 |
-
# -*- coding: utf-8 -*-
|
| 2 |
-
|
| 3 |
-
# ==============================================================================
|
| 4 |
-
# Original work Copyright (c) 2025 Tencent.
|
| 5 |
-
# Modified work Copyright (c) 2025 UltraShape Team.
|
| 6 |
-
#
|
| 7 |
-
# Modified by UltraShape on 2025.12.25
|
| 8 |
-
# ==============================================================================
|
| 9 |
-
|
| 10 |
-
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
| 11 |
-
# except for the third-party components listed below.
|
| 12 |
-
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
| 13 |
-
# in the respective licenses of these third-party components.
|
| 14 |
-
# Users must comply with all terms and conditions of original licenses of these third-party
|
| 15 |
-
# components and must ensure that the usage of the third party components adheres to
|
| 16 |
-
# all relevant laws and regulations.
|
| 17 |
-
|
| 18 |
-
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
| 19 |
-
# their software and algorithms, including trained model weights, parameters (including
|
| 20 |
-
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
| 21 |
-
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
| 22 |
-
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
import os
|
| 26 |
-
import cv2
|
| 27 |
-
import json
|
| 28 |
-
import math
|
| 29 |
-
import random
|
| 30 |
-
import imageio
|
| 31 |
-
import pickle
|
| 32 |
-
import numpy as np
|
| 33 |
-
from PIL import Image
|
| 34 |
-
import pandas as pd
|
| 35 |
-
from dataclasses import dataclass, field
|
| 36 |
-
|
| 37 |
-
import torch
|
| 38 |
-
import torch.nn.functional as F
|
| 39 |
-
import pytorch_lightning as pl
|
| 40 |
-
from torch.utils.data import DataLoader, Dataset
|
| 41 |
-
import torchvision.transforms as transforms
|
| 42 |
-
from pytorch_lightning.utilities import rank_zero_info
|
| 43 |
-
from ultrashape.utils.typing import *
|
| 44 |
-
|
| 45 |
-
class ObjaverseDataset(Dataset):
    """
    Dataset of surface point clouds with SDF supervision points.

    Each item is identified by a uid and read from
    ``{sample_root}/{uid}.npz``.
    """

    def __init__(
        self,
        data_json,
        sample_root,
        pc_size: int = 2048,
        pc_sharpedge_size: int = 2048,
        sup_near_uni_size: int = 4096,
        sup_near_sharp_size: int = 4096,
        sup_space_size: int = 4096,
        tsdf_threshold: float = 0.05,
        sharpedge_label: bool = False,
        return_normal: bool = False,
    ):
        """
        Args:
            data_json: Path to a JSON file holding the uids of this split.
            sample_root: Directory containing one ``{uid}.npz`` per uid.
            pc_size: Number of surface points to sample (half from the
                uniform part, half from the sharp part).
            pc_sharpedge_size: Only used in the size assertion of
                ``_load_shape`` and in the log output.
            sup_near_uni_size: Number of uniform near-surface supervision
                points to sample.
            sup_near_sharp_size: Number of curvature near-surface
                supervision points to sample.
            sup_space_size: Number of space supervision points to sample.
            tsdf_threshold: Truncation distance used by ``_clip_to_tsdf``.
            sharpedge_label: If True, append a 0/1 column marking points
                taken from the sharp part.
            return_normal: If True, append the normalized normals to the
                returned surface tensor.
        """
        super().__init__()

        # Context manager makes sure the JSON file is closed right away.
        with open(data_json) as uid_file:
            self.uids = json.load(uid_file)
        self.sample_root = sample_root

        self.pc_size = pc_size
        self.pc_sharpedge_size = pc_sharpedge_size
        self.sharpedge_label = sharpedge_label
        self.return_normal = return_normal

        self.sup_near_uni_size = sup_near_uni_size
        self.sup_near_sharp_size = sup_near_sharp_size
        self.sup_space_size = sup_space_size
        self.tsdf_threshold = tsdf_threshold

        print(f"Loaded {len(self.uids)} uids from {data_json}.")

        rank_zero_info(f'*' * 50)
        rank_zero_info(f'Dataset Infos:')
        rank_zero_info(f'# of 3D file: {len(self.uids)}')
        rank_zero_info(f'# of Surface Points: {self.pc_size}')
        rank_zero_info(f'# of Sharpedge Surface Points: {self.pc_sharpedge_size}')
        rank_zero_info(f'# of Uniform Near-Surface Sup-Points: {self.sup_near_uni_size}')
        rank_zero_info(f'# of Sharpedge Near-Surface Sup-Points: {self.sup_near_sharp_size}')
        rank_zero_info(f'# of Random Space Sup-Points: {self.sup_space_size}')
        rank_zero_info(f'Using sharp edge label: {self.sharpedge_label}')
        rank_zero_info(f'*' * 50)

    def __len__(self):
        """Return the number of uids in this split."""
        return len(self.uids)

    def _load_shape(self, index: int) -> Dict[str, Any]:
        """
        Load and subsample surface and supervision points of item ``index``.

        Returns:
            Dict with ``uid``, ``surface`` (float tensor whose columns are
            xyz, optionally followed by the normalized normal and the
            sharp-edge label), the three supervision arrays
            ``sup_near_uniform`` / ``sup_near_sharp`` / ``sup_space`` (point
            coordinates with the truncated SDF as last column) and the
            placeholder ``geo_points`` (0.0).
        """
        rng = np.random.default_rng()

        # The NpzFile keeps the archive open until it is closed, so read
        # everything needed inside the context manager.
        with np.load(f'{self.sample_root}/{self.uids[index]}.npz') as data:
            ##################### sup pcd&sdf ######################
            # Points are rescaled with (x - 0.5) * 2 and SDF values with * 2;
            # presumably the data is stored in a [0, 1] cube -- TODO confirm.
            uniform_near_points = (np.asarray(data['uniform_near_points']) - 0.5) * 2
            curvature_near_points = (np.asarray(data['curvature_near_points']) - 0.5) * 2
            space_points = (np.asarray(data['space_points']) - 0.5) * 2
            uniform_near_sdf = np.asarray(data['uniform_near_sdf']) * 2
            curvature_near_sdf = np.asarray(data['curvature_near_sdf']) * 2
            space_sdf = np.asarray(data['space_sdf']) * 2

            surface_og = (np.asarray(data['clean_surface_points']) - 0.5) * 2
            normal = np.asarray(data['clean_surface_normals'])

        uni_noisy_idx = rng.choice(uniform_near_points.shape[0], self.sup_near_uni_size, replace=False)
        cur_noisy_idx = rng.choice(curvature_near_points.shape[0], self.sup_near_sharp_size, replace=False)
        space_idx = rng.choice(space_points.shape[0], self.sup_space_size, replace=False)

        # The same indices select points and their SDF values.
        uniform_near_points = uniform_near_points[uni_noisy_idx]
        curvature_near_points = curvature_near_points[cur_noisy_idx]
        space_points = space_points[space_idx]
        uniform_near_sdf = uniform_near_sdf[uni_noisy_idx]
        curvature_near_sdf = curvature_near_sdf[cur_noisy_idx]
        space_sdf = space_sdf[space_idx]

        uniform_near_sdf, curvature_near_sdf, space_sdf = map(self._clip_to_tsdf, (uniform_near_sdf, curvature_near_sdf, space_sdf))

        surface_og_n = np.concatenate([surface_og, normal], axis=1)

        # hard code: first 300k are uniform, last 300k are sharp
        assert surface_og_n.shape[0] == 600000, f"assume that suface points = 30w uniform + 30w curvature, but {len(surface_og_n)=}"
        coarse_surface = surface_og_n[:300000]
        sharp_surface = surface_og_n[300000:]

        surface_normal = []

        # NOTE(review): the extent of this guard was reconstructed from
        # de-indented text; both halves depend on pc_size, so both are
        # placed under it. Behavior is identical for any pc_size > 0.
        if self.pc_size > 0:
            ind = rng.choice(coarse_surface.shape[0], self.pc_size // 2, replace=False)
            coarse_surface = coarse_surface[ind]
            if self.sharpedge_label:
                sharpedge_label = np.zeros((self.pc_size // 2, 1))
                coarse_surface = np.concatenate((coarse_surface, sharpedge_label), axis=1)
            surface_normal.append(coarse_surface)

            ind_sharpedge = rng.choice(sharp_surface.shape[0], self.pc_size // 2, replace=False)
            sharp_surface = sharp_surface[ind_sharpedge]
            if self.sharpedge_label:
                sharpedge_label = np.ones((self.pc_size // 2, 1))
                sharp_surface = np.concatenate((sharp_surface, sharpedge_label), axis=1)
            surface_normal.append(sharp_surface)

        surface_normal = np.concatenate(surface_normal, axis=0)
        surface_normal = torch.FloatTensor(surface_normal)
        surface = surface_normal[:, 0:3]
        normal = surface_normal[:, 3:6]
        # NOTE(review): the sampling above always yields 2 * (pc_size // 2)
        # rows, so this only holds for an even pc_size combined with
        # pc_sharpedge_size == 0 -- confirm the intended configuration.
        assert surface.shape[0] == self.pc_size + self.pc_sharpedge_size

        geo_points = 0.0
        normal = torch.nn.functional.normalize(normal, p=2, dim=1)
        if self.return_normal:
            surface = torch.cat([surface, normal], dim=-1)
        if self.sharpedge_label:
            # Last column of surface_normal is the 0/1 sharp-edge label.
            surface = torch.cat([surface, surface_normal[:, -1:]], dim=-1)

        ret = {
            "uid": self.uids[index],
            "surface": surface,
            "sup_near_uniform": np.concatenate([uniform_near_points, uniform_near_sdf[...,None]], axis=1),
            "sup_near_sharp": np.concatenate([curvature_near_points, curvature_near_sdf[...,None]], axis=1),
            "sup_space": np.concatenate([space_points, space_sdf[...,None]], axis=1),
            "geo_points": geo_points
        }
        return ret

    def _clip_to_tsdf(self, sdf: np.ndarray):
        """
        Convert SDF values to a truncated SDF scaled to [-1, 1].

        NaN entries are replaced by 1.0 first. The values are then flattened,
        cast to float32, clipped to ``[-tsdf_threshold, tsdf_threshold]`` and
        divided by ``tsdf_threshold``.
        """
        nan_mask = np.isnan(sdf)
        if np.any(nan_mask):
            sdf = np.nan_to_num(sdf, nan=1.0, posinf=1.0, neginf=-1.0)
        return sdf.flatten().astype(np.float32).clip(-self.tsdf_threshold, self.tsdf_threshold) / self.tsdf_threshold

    def get_data(self, index):
        """Return the shape dictionary of item ``index``."""
        ret = self._load_shape(index)
        return ret

    def __getitem__(self, index):
        """Return item ``index``; loading errors propagate to the caller."""
        return self.get_data(index)

    def collate(self, batch):
        """Collate a list of items with ``torch.utils.data.default_collate``."""
        batch = torch.utils.data.default_collate(batch)
        return batch
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
class ObjaverseDataModule(pl.LightningDataModule):
|
| 185 |
-
def __init__(
|
| 186 |
-
self,
|
| 187 |
-
batch_size: int = 1,
|
| 188 |
-
num_workers: int = 4,
|
| 189 |
-
val_num_workers: int = 2,
|
| 190 |
-
training_data_list: str = None,
|
| 191 |
-
sample_pcd_dir: str = None,
|
| 192 |
-
pc_size: int = 2048,
|
| 193 |
-
pc_sharpedge_size: int = 2048,
|
| 194 |
-
sup_near_uni_size: int = 4096,
|
| 195 |
-
sup_near_sharp_size: int = 4096,
|
| 196 |
-
sup_space_size: int = 4096,
|
| 197 |
-
tsdf_threshold: float = 0.05,
|
| 198 |
-
sharpedge_label: bool = False,
|
| 199 |
-
return_normal: bool = False,
|
| 200 |
-
):
|
| 201 |
-
|
| 202 |
-
super().__init__()
|
| 203 |
-
self.batch_size = batch_size
|
| 204 |
-
self.num_workers = num_workers
|
| 205 |
-
self.val_num_workers = val_num_workers
|
| 206 |
-
|
| 207 |
-
self.training_data_list = training_data_list
|
| 208 |
-
self.sample_pcd_dir = sample_pcd_dir
|
| 209 |
-
|
| 210 |
-
self.pc_size = pc_size
|
| 211 |
-
self.pc_sharpedge_size = pc_sharpedge_size
|
| 212 |
-
self.sharpedge_label = sharpedge_label
|
| 213 |
-
self.return_normal = return_normal
|
| 214 |
-
|
| 215 |
-
self.sup_near_uni_size = sup_near_uni_size
|
| 216 |
-
self.sup_near_sharp_size = sup_near_sharp_size
|
| 217 |
-
self.sup_space_size = sup_space_size
|
| 218 |
-
self.tsdf_threshold = tsdf_threshold
|
| 219 |
-
|
| 220 |
-
def train_dataloader(self):
|
| 221 |
-
asl_params = {
|
| 222 |
-
"data_json": f'{self.training_data_list}/train.json',
|
| 223 |
-
"sample_root": self.sample_pcd_dir,
|
| 224 |
-
"pc_size": self.pc_size,
|
| 225 |
-
"pc_sharpedge_size": self.pc_sharpedge_size,
|
| 226 |
-
"sup_near_uni_size": self.sup_near_uni_size,
|
| 227 |
-
"sup_near_sharp_size": self.sup_near_sharp_size,
|
| 228 |
-
"sup_space_size": self.sup_space_size,
|
| 229 |
-
"tsdf_threshold": self.tsdf_threshold,
|
| 230 |
-
"sharpedge_label": self.sharpedge_label,
|
| 231 |
-
"return_normal": self.return_normal,
|
| 232 |
-
}
|
| 233 |
-
dataset = ObjaverseDataset(**asl_params)
|
| 234 |
-
return torch.utils.data.DataLoader(
|
| 235 |
-
dataset,
|
| 236 |
-
batch_size=self.batch_size,
|
| 237 |
-
num_workers=self.num_workers,
|
| 238 |
-
pin_memory=True,
|
| 239 |
-
drop_last=True,
|
| 240 |
-
)
|
| 241 |
-
|
| 242 |
-
def val_dataloader(self):
    """Build the DataLoader for the validation split.

    The dataset is an ``ObjaverseDataset`` reading
    ``<training_data_list>/val.json``; every sampling option is forwarded
    from the attributes stored on this data module.

    Returns:
        A ``torch.utils.data.DataLoader`` using ``self.batch_size`` and
        ``self.val_num_workers``, with pinned memory.  Note that
        ``drop_last=True`` is set here too, so a trailing incomplete
        validation batch is discarded.
    """
    dataset_kwargs = dict(
        data_json=f'{self.training_data_list}/val.json',
        sample_root=self.sample_pcd_dir,
        pc_size=self.pc_size,
        pc_sharpedge_size=self.pc_sharpedge_size,
        sup_near_uni_size=self.sup_near_uni_size,
        sup_near_sharp_size=self.sup_near_sharp_size,
        sup_space_size=self.sup_space_size,
        tsdf_threshold=self.tsdf_threshold,
        sharpedge_label=self.sharpedge_label,
        return_normal=self.return_normal,
    )
    val_set = ObjaverseDataset(**dataset_kwargs)

    loader = torch.utils.data.DataLoader(
        val_set,
        batch_size=self.batch_size,
        num_workers=self.val_num_workers,
        pin_memory=True,
        drop_last=True,
    )
    return loader
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ultrashape/data/utils.py
DELETED
|
@@ -1,193 +0,0 @@
|
|
| 1 |
-
# -*- coding: utf-8 -*-
|
| 2 |
-
|
| 3 |
-
# ==============================================================================
|
| 4 |
-
# Original work Copyright (c) 2025 Tencent.
|
| 5 |
-
# Modified work Copyright (c) 2025 UltraShape Team.
|
| 6 |
-
#
|
| 7 |
-
# Modified by UltraShape on 2025.12.25
|
| 8 |
-
# ==============================================================================
|
| 9 |
-
|
| 10 |
-
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
|
| 11 |
-
# This file is part of the WebDataset library.
|
| 12 |
-
# See the LICENSE file for licensing terms (BSD-style).
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
"""Miscellaneous utility functions."""
|
| 16 |
-
|
| 17 |
-
import importlib
|
| 18 |
-
import itertools as itt
|
| 19 |
-
import os
|
| 20 |
-
import re
|
| 21 |
-
import sys
|
| 22 |
-
from typing import Any, Callable, Iterator, Union
|
| 23 |
-
import torch
|
| 24 |
-
import numpy as np
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
def make_seed(*args):
    """Fold the hashes of ``args`` into a single non-negative 31-bit seed.

    Each argument is mixed in with a multiply-by-31 rolling hash and the
    running value is masked to 31 bits after every step.  With no arguments
    the result is 0.
    """
    mask_31_bits = 0x7FFFFFFF
    accumulator = 0
    for item in args:
        accumulator = (accumulator * 31 + hash(item)) & mask_31_bits
    return accumulator
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
class PipelineStage:
    """Base class for pipeline stages; ``invoke`` must be overridden."""

    def invoke(self, *args, **kw):
        """Run the stage.  The base implementation always raises.

        Raises:
            NotImplementedError: unconditionally.
        """
        raise NotImplementedError
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
def identity(x: Any) -> Any:
    """Return ``x`` itself, unchanged (the very same object)."""
    unchanged = x
    return unchanged
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
def safe_eval(s: str, expr: str = "{}"):
    """Evaluate the given expression more safely.

    Args:
        s: Text substituted into ``expr``.  It may contain only ASCII
            letters, digits and underscores.
        expr: ``str.format`` template that receives ``s``; the formatted
            result is passed to ``eval``.

    Returns:
        Whatever ``eval`` produces for the formatted expression.

    Raises:
        ValueError: if ``s`` contains any other character.
    """
    # Stripping every disallowed character must leave the string unchanged.
    if re.sub("[^A-Za-z0-9_]", "", s) != s:
        raise ValueError(f"safe_eval: illegal characters in: '{s}'")
    # NOTE(review): only ``s`` is validated; ``expr`` is evaluated as-is, and
    # eval() can see this function's locals and globals.  Never pass an
    # ``expr`` that comes from untrusted input.
    return eval(expr.format(s))
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
def lookup_sym(sym: str, modules: list):
    """Return the first non-``None`` attribute named ``sym`` among ``modules``.

    Args:
        sym: Attribute name to look for.
        modules: Module names, tried in order.  Each is imported with
            ``package="webdataset"`` as the anchor for relative names.

    Returns:
        The attribute from the first module that has it (and where it is not
        ``None``), or ``None`` when no module provides it.
    """
    for module_name in modules:
        imported = importlib.import_module(module_name, package="webdataset")
        candidate = getattr(imported, sym, None)
        if candidate is None:
            continue
        return candidate
    return None
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
def repeatedly0(
    loader: Iterator, nepochs: int = sys.maxsize, nbatches: int = sys.maxsize
):
    """Yield batches from ``loader`` for several passes.

    Args:
        loader: Iterable that is re-iterated once per epoch.
        nepochs: Number of passes over ``loader``.
        nbatches: Maximum number of items taken from ``loader`` in each pass.

    Yields:
        Up to ``nbatches`` items per pass, for ``nepochs`` passes.
    """
    for _epoch in range(nepochs):
        one_pass = itt.islice(loader, nbatches)
        for item in one_pass:
            yield item
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
def guess_batchsize(batch: Union[tuple, list]):
    """Guess the batch size by looking at the length of the first element in a tuple."""
    return len(batch[0])


def repeatedly(
    source: Iterator,
    nepochs: int = None,
    nbatches: int = None,
    nsamples: int = None,
    batchsize: Callable[..., int] = guess_batchsize,
):
    """Repeatedly yield samples from an iterable until a limit is reached.

    Args:
        source: Iterable that is re-iterated once per epoch.
        nepochs: Stop after this many complete passes (``None`` = no limit).
        nbatches: Stop after this many yielded items (``None`` = no limit).
        nsamples: Stop once the accumulated batch sizes reach this value
            (``None`` = no limit).
        batchsize: Callable returning the number of samples in one yielded
            item; only consulted when ``nsamples`` is given.

    Yields:
        Items of ``source``, in order, pass after pass.
    """
    epoch = 0
    batch = 0
    total = 0
    while True:
        produced_any = False
        for sample in source:
            produced_any = True
            yield sample
            batch += 1
            if nbatches is not None and batch >= nbatches:
                return
            if nsamples is not None:
                # Use the caller-supplied size function rather than always
                # falling back to guess_batchsize.
                total += batchsize(sample)
                if total >= nsamples:
                    return
        if not produced_any:
            # An empty (or already exhausted one-shot) source can never make
            # progress; stop instead of spinning forever.
            return
        epoch += 1
        if nepochs is not None and epoch >= nepochs:
            return
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
def pytorch_worker_info(group=None):  # sourcery skip: use-contextlib-suppress
    """Return node and worker info for PyTorch and some distributed environments.

    The node identity is read from the ``RANK`` / ``WORLD_SIZE`` environment
    variables when both are set, otherwise from ``torch.distributed`` if it is
    available and initialized.  The worker identity is read from ``WORKER`` /
    ``NUM_WORKERS`` when both are set, otherwise from
    ``torch.utils.data.get_worker_info()``.

    Args:
        group: Optional process group; defaults to the WORLD group when
            ``torch.distributed`` is queried.

    Returns:
        Tuple ``(rank, world_size, worker, num_workers)``; it is
        ``(0, 1, 0, 1)`` when nothing else can be determined.
    """
    env = os.environ

    # --- node-level identity -------------------------------------------------
    rank, world_size = 0, 1
    if "RANK" in env and "WORLD_SIZE" in env:
        rank = int(env["RANK"])
        world_size = int(env["WORLD_SIZE"])
    else:
        try:
            import torch.distributed

            dist = torch.distributed
            if dist.is_available() and dist.is_initialized():
                group = group or dist.group.WORLD
                rank = dist.get_rank(group=group)
                world_size = dist.get_world_size(group=group)
        except ModuleNotFoundError:
            pass

    # --- DataLoader worker identity -------------------------------------------
    worker, num_workers = 0, 1
    if "WORKER" in env and "NUM_WORKERS" in env:
        worker = int(env["WORKER"])
        num_workers = int(env["NUM_WORKERS"])
    else:
        try:
            import torch.utils.data

            info = torch.utils.data.get_worker_info()
            if info is not None:
                worker = info.id
                num_workers = info.num_workers
        except ModuleNotFoundError:
            pass

    return rank, world_size, worker, num_workers
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
def pytorch_worker_seed(group=None):
    """Derive a deterministic RNG seed from the node rank and worker index.

    The seed is ``rank * 1000 + worker`` as reported by
    ``pytorch_worker_info``; the world size and worker count are not used.
    """
    rank, _world_size, worker, _num_workers = pytorch_worker_info(group=group)
    seed = rank * 1000 + worker
    return seed
|
| 140 |
-
|
| 141 |
-
def worker_init_fn(_):
    """Re-seed NumPy's global RNG with a per-worker offset.

    The new seed is the first word of the current NumPy RNG state plus the
    DataLoader worker id, wrapped into the 32-bit range accepted by
    ``np.random.seed``.

    Args:
        _: Ignored positional argument; the worker id is taken from
            ``torch.utils.data.get_worker_info()`` instead.

    Returns:
        The result of ``np.random.seed`` (i.e. ``None``).
    """
    worker_info = torch.utils.data.get_worker_info()
    # get_worker_info() is None outside a worker process; treat that as id 0
    # instead of failing with AttributeError.
    worker_id = 0 if worker_info is None else worker_info.id

    base_seed = int(np.random.get_state()[1][0])
    # np.random.seed only accepts values in [0, 2**32 - 1]; wrap around so a
    # large state word plus the worker id can never overflow that range.
    return np.random.seed((base_seed + worker_id) % (2 ** 32))
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
def collation_fn(samples, combine_tensors=True, combine_scalars=True):
    """Collate a list of sample dicts into one dict of per-key batches.

    The keys are taken from the first sample.  For each key the values of all
    samples are gathered into a list, which is then combined according to the
    type of its first element.

    Args:
        samples (list[dict]): Samples sharing the keys of ``samples[0]``.
        combine_tensors: If true, ``torch.Tensor`` values are stacked with
            ``torch.stack`` and ``np.ndarray`` values with ``np.stack``.
        combine_scalars: If true, ``int``/``float`` values are turned into a
            NumPy array.

    Returns:
        dict mapping each key to the combined batch; values of any other
        type (or with combining disabled) stay a plain list.
    """
    field_names = samples[0].keys()

    # Gather one column (list of values) per key.
    batch = {name: [item[name] for item in samples] for name in field_names}

    for name in field_names:
        column = batch[name]
        head = column[0]

        if isinstance(head, (int, float)):
            if combine_scalars:
                batch[name] = np.array(column)
        elif isinstance(head, torch.Tensor):
            if combine_tensors:
                batch[name] = torch.stack(column)
        elif isinstance(head, np.ndarray):
            if combine_tensors:
                batch[name] = np.stack(column)

    return batch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ultrashape/models/__init__.py
DELETED
|
@@ -1,27 +0,0 @@
|
|
| 1 |
-
# Open Source Model Licensed under the Apache License Version 2.0
|
| 2 |
-
# and Other Licenses of the Third-Party Components therein:
|
| 3 |
-
# The below Model in this distribution may have been modified by THL A29 Limited
|
| 4 |
-
# ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
|
| 5 |
-
|
| 6 |
-
# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
|
| 7 |
-
# The below software and/or models in this distribution may have been
|
| 8 |
-
# modified by THL A29 Limited ("Tencent Modifications").
|
| 9 |
-
# All Tencent Modifications are Copyright (C) THL A29 Limited.
|
| 10 |
-
|
| 11 |
-
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
| 12 |
-
# except for the third-party components listed below.
|
| 13 |
-
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
| 14 |
-
# in the repsective licenses of these third-party components.
|
| 15 |
-
# Users must comply with all terms and conditions of original licenses of these third-party
|
| 16 |
-
# components and must ensure that the usage of the third party components adheres to
|
| 17 |
-
# all relevant laws and regulations.
|
| 18 |
-
|
| 19 |
-
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
| 20 |
-
# their software and algorithms, including trained model weights, parameters (including
|
| 21 |
-
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
| 22 |
-
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
| 23 |
-
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
| 24 |
-
|
| 25 |
-
from .autoencoders import ShapeVAE
|
| 26 |
-
from .conditioner_mask import DualImageEncoder, SingleImageEncoder, DinoImageEncoder, CLIPImageEncoder
|
| 27 |
-
from .denoisers import RefineDiT
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ultrashape/models/autoencoders/__init__.py
DELETED
|
@@ -1,21 +0,0 @@
|
|
| 1 |
-
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
| 2 |
-
# except for the third-party components listed below.
|
| 3 |
-
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
| 4 |
-
# in the repsective licenses of these third-party components.
|
| 5 |
-
# Users must comply with all terms and conditions of original licenses of these third-party
|
| 6 |
-
# components and must ensure that the usage of the third party components adheres to
|
| 7 |
-
# all relevant laws and regulations.
|
| 8 |
-
|
| 9 |
-
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
| 10 |
-
# their software and algorithms, including trained model weights, parameters (including
|
| 11 |
-
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
| 12 |
-
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
| 13 |
-
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
| 14 |
-
|
| 15 |
-
from .attention_blocks import CrossAttentionDecoder
|
| 16 |
-
from .attention_processors import FlashVDMCrossAttentionProcessor, CrossAttentionProcessor, \
|
| 17 |
-
FlashVDMTopMCrossAttentionProcessor
|
| 18 |
-
from .model import ShapeVAE, VectsetVAE
|
| 19 |
-
from .surface_extractors import SurfaceExtractors, MCSurfaceExtractor, DMCSurfaceExtractor, Latent2MeshOutput
|
| 20 |
-
from .volume_decoders import HierarchicalVolumeDecoding, FlashVDMVolumeDecoding, VanillaVolumeDecoder
|
| 21 |
-
from .vae_trainer import VAETrainer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ultrashape/models/autoencoders/attention_blocks.py
DELETED
|
@@ -1,711 +0,0 @@
|
|
| 1 |
-
# ==============================================================================
|
| 2 |
-
# Original work Copyright (c) 2025 Tencent.
|
| 3 |
-
# Modified work Copyright (c) 2025 UltraShape Team.
|
| 4 |
-
#
|
| 5 |
-
# Modified by UltraShape on 2025.12.25
|
| 6 |
-
# ==============================================================================
|
| 7 |
-
|
| 8 |
-
# Open Source Model Licensed under the Apache License Version 2.0
|
| 9 |
-
# and Other Licenses of the Third-Party Components therein:
|
| 10 |
-
# The below Model in this distribution may have been modified by THL A29 Limited
|
| 11 |
-
# ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
|
| 12 |
-
|
| 13 |
-
# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
|
| 14 |
-
# The below software and/or models in this distribution may have been
|
| 15 |
-
# modified by THL A29 Limited ("Tencent Modifications").
|
| 16 |
-
# All Tencent Modifications are Copyright (C) THL A29 Limited.
|
| 17 |
-
|
| 18 |
-
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
| 19 |
-
# except for the third-party components listed below.
|
| 20 |
-
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
| 21 |
-
# in the repsective licenses of these third-party components.
|
| 22 |
-
# Users must comply with all terms and conditions of original licenses of these third-party
|
| 23 |
-
# components and must ensure that the usage of the third party components adheres to
|
| 24 |
-
# all relevant laws and regulations.
|
| 25 |
-
|
| 26 |
-
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
| 27 |
-
# their software and algorithms, including trained model weights, parameters (including
|
| 28 |
-
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
| 29 |
-
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
| 30 |
-
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
import os
|
| 34 |
-
from typing import Optional, Union, List
|
| 35 |
-
|
| 36 |
-
import torch
|
| 37 |
-
import torch.nn as nn
|
| 38 |
-
from einops import rearrange
|
| 39 |
-
from torch import Tensor
|
| 40 |
-
|
| 41 |
-
from .attention_processors import CrossAttentionProcessor
|
| 42 |
-
from ...utils import logger
|
| 43 |
-
from ultrashape.utils import voxelize_from_point
|
| 44 |
-
|
| 45 |
-
# Attention kernel used by the self-attention layers in this module.
# Defaults to PyTorch's built-in implementation.
scaled_dot_product_attention = nn.functional.scaled_dot_product_attention

# Setting the environment variable USE_SAGEATTN=1 swaps in the kernel from the
# optional "sageattention" package instead.
if os.environ.get('USE_SAGEATTN', '0') == '1':
    try:
        from sageattention import sageattn
    except ImportError as err:
        # Chain the original error so the real import failure stays visible.
        raise ImportError('Please install the package "sageattention" to use this USE_SAGEATTN.') from err
    scaled_dot_product_attention = sageattn
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
class FourierEmbedder(nn.Module):
    """Sine/cosine positional embedding of the last input dimension."""

    def __init__(self,
                 num_freqs: int = 6,
                 logspace: bool = True,
                 input_dim: int = 3,
                 include_input: bool = True,
                 include_pi: bool = True) -> None:
        """Set up the frequency bank.

        Args:
            num_freqs: Number of frequencies; 0 disables the embedding.
            logspace: If true use ``2**0 .. 2**(num_freqs-1)``, otherwise
                ``num_freqs`` values evenly spaced between 1 and
                ``2**(num_freqs-1)``.
            input_dim: Size of the last input dimension, used for ``out_dim``.
            include_input: If true the raw input is prepended to the output.
            include_pi: If true every frequency is multiplied by pi.
        """
        super().__init__()

        if not logspace:
            bank = torch.linspace(
                1.0,
                2.0 ** (num_freqs - 1),
                num_freqs,
                dtype=torch.float32
            )
        else:
            bank = 2.0 ** torch.arange(
                num_freqs,
                dtype=torch.float32
            )

        if include_pi:
            bank = bank * torch.pi

        # Non-persistent buffer: follows the module across devices but is not
        # written to the state dict.
        self.register_buffer("frequencies", bank, persistent=False)
        self.include_input = include_input
        self.num_freqs = num_freqs

        self.out_dim = self.get_dims(input_dim)

    def get_dims(self, input_dim):
        """Return the output width produced for an input of width ``input_dim``."""
        # With no frequencies forward() returns the input itself, so the raw
        # input always counts in that case.
        keeps_input = self.include_input or self.num_freqs == 0
        per_coordinate = self.num_freqs * 2 + (1 if keeps_input else 0)
        return input_dim * per_coordinate

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """ Forward process.

        Args:
            x: tensor of shape [..., dim]

        Returns:
            embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)]
                where temp is 1 if include_input is True and 0 otherwise.
        """
        if self.num_freqs <= 0:
            return x

        # [..., dim, 1] * [num_freqs] -> [..., dim, num_freqs], then flattened.
        scaled = x[..., None].contiguous() * self.frequencies
        embed = scaled.view(*x.shape[:-1], -1)

        parts = (embed.sin(), embed.cos())
        if self.include_input:
            parts = (x,) + parts
        return torch.cat(parts, dim=-1)
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
class DropPath(nn.Module):
    """Stochastic depth: randomly zero whole samples of a residual branch."""

    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
        """Store the drop probability and whether survivors are rescaled.

        Args:
            drop_prob: Probability of dropping a sample's path.
            scale_by_keep: If true, kept samples are divided by the keep
                probability.
        """
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        """Apply drop-path to ``x``.

        In eval mode, or when ``drop_prob`` is 0, ``x`` is returned as is.
        Otherwise one Bernoulli keep/drop decision is drawn per entry of the
        first dimension and broadcast over all remaining dimensions.
        """
        inactive = self.drop_prob == 0. or not self.training
        if inactive:
            return x

        keep_prob = 1 - self.drop_prob
        # One mask value per sample, broadcastable to any tensor rank.
        mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
        if keep_prob > 0.0 and self.scale_by_keep:
            mask.div_(keep_prob)
        return x * mask

    def extra_repr(self):
        """Show the drop probability, rounded to three decimals, in ``repr``."""
        return f'drop_prob={round(self.drop_prob, 3):0.3f}'
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
class MLP(nn.Module):
    """Two-layer feed-forward block: Linear -> GELU -> Linear -> drop-path."""

    def __init__(
        self, *,
        width: int,
        expand_ratio: int = 4,
        output_width: int = None,
        drop_path_rate: float = 0.0
    ):
        """Create the layers.

        Args:
            width: Input feature width.
            expand_ratio: Hidden width is ``width * expand_ratio``.
            output_width: Output feature width; defaults to ``width``.
            drop_path_rate: Drop-path probability applied to the output; a
                rate of 0 (or below) installs ``nn.Identity`` instead.
        """
        super().__init__()
        hidden_width = width * expand_ratio
        final_width = width if output_width is None else output_width

        self.width = width
        self.c_fc = nn.Linear(width, hidden_width)
        self.c_proj = nn.Linear(hidden_width, final_width)
        self.gelu = nn.GELU()
        if drop_path_rate > 0.:
            self.drop_path = DropPath(drop_path_rate)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x):
        """Run ``x`` through expand, GELU, project and drop-path."""
        hidden = self.gelu(self.c_fc(x))
        projected = self.c_proj(hidden)
        return self.drop_path(projected)
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
class QKVMultiheadCrossAttention(nn.Module):
    """Multi-head cross-attention core working on projected q and packed kv."""

    def __init__(
        self,
        *,
        heads: int,
        n_data: Optional[int] = None,
        width=None,
        qk_norm=False,
        norm_layer=nn.LayerNorm
    ):
        """Configure the heads and optional per-head q/k normalization.

        Args:
            heads: Number of attention heads.
            n_data: Stored on the module; not used by ``forward``.
            width: Model width; only needed when ``qk_norm`` is true, where
                the norm layers operate on ``width // heads`` channels.
            qk_norm: If true, normalize q and k per head before attention.
            norm_layer: Norm layer class used when ``qk_norm`` is true.
        """
        super().__init__()
        self.heads = heads
        self.n_data = n_data
        if qk_norm:
            self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6)
            self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6)
        else:
            self.q_norm = nn.Identity()
            self.k_norm = nn.Identity()

        # Pluggable attention implementation, called as processor(self, q, k, v).
        self.attn_processor = CrossAttentionProcessor()

    def forward(self, q, kv):
        """Attend from ``q`` (3-D) to the keys/values packed in ``kv`` (3-D).

        The last dimension of ``kv`` holds, per head, the key channels
        followed by the value channels.  Returns a tensor with the batch size
        of ``kv`` and the sequence length of ``q``.
        """
        _, n_ctx, _ = q.shape
        bs, n_data, width = kv.shape
        head_channels = width // self.heads // 2

        q = q.view(bs, n_ctx, self.heads, -1)
        packed = kv.view(bs, n_data, self.heads, -1)
        k, v = torch.split(packed, head_channels, dim=-1)

        q = self.q_norm(q)
        k = self.k_norm(k)

        # Move the head axis in front of the sequence axis.
        q, k, v = (
            rearrange(t, 'b n h d -> b h n d', h=self.heads)
            for t in (q, k, v)
        )
        attended = self.attn_processor(self, q, k, v)
        return attended.transpose(1, 2).reshape(bs, n_ctx, -1)
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
class MultiheadCrossAttention(nn.Module):
    """Cross-attention layer with q / kv input projections and output projection."""

    def __init__(
        self,
        *,
        width: int,
        heads: int,
        qkv_bias: bool = True,
        n_data: Optional[int] = None,
        data_width: Optional[int] = None,
        norm_layer=nn.LayerNorm,
        qk_norm: bool = False,
        kv_cache: bool = False,
    ):
        """Create the projections and the attention core.

        Args:
            width: Width of the query stream and of the output.
            heads: Number of attention heads.
            qkv_bias: Whether the q and kv projections have a bias.
            n_data: Forwarded to the attention core and stored.
            data_width: Width of the data stream; defaults to ``width``.
            norm_layer: Norm layer class for optional q/k normalization.
            qk_norm: Whether q and k are normalized per head.
            kv_cache: If true, the projected kv of the first call is kept and
                reused by every later call.
        """
        super().__init__()
        self.n_data = n_data
        self.width = width
        self.heads = heads
        self.data_width = data_width if data_width is not None else width
        self.c_q = nn.Linear(width, width, bias=qkv_bias)
        self.c_kv = nn.Linear(self.data_width, width * 2, bias=qkv_bias)
        self.c_proj = nn.Linear(width, width)
        self.attention = QKVMultiheadCrossAttention(
            heads=heads,
            n_data=n_data,
            width=width,
            norm_layer=norm_layer,
            qk_norm=qk_norm
        )
        self.kv_cache = kv_cache
        # Cached kv projection; filled lazily when kv_cache is enabled.
        self.data = None

    def forward(self, x, data):
        """Attend from ``x`` to ``data`` and project the result.

        With ``kv_cache`` enabled the kv projection is computed only while
        the cache is empty; afterwards the ``data`` argument is ignored in
        favour of the cached projection.
        """
        queries = self.c_q(x)

        if not self.kv_cache:
            keys_values = self.c_kv(data)
        else:
            if self.data is None:
                self.data = self.c_kv(data)
                logger.info('Save kv cache,this should be called only once for one mesh')
            keys_values = self.data

        attended = self.attention(queries, keys_values)
        return self.c_proj(attended)
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
class ResidualCrossAttentionBlock(nn.Module):
    """Pre-norm residual block: cross-attention followed by an MLP."""

    def __init__(
        self,
        *,
        n_data: Optional[int] = None,
        width: int,
        heads: int,
        mlp_expand_ratio: int = 4,
        data_width: Optional[int] = None,
        qkv_bias: bool = True,
        norm_layer=nn.LayerNorm,
        qk_norm: bool = False
    ):
        """Create the attention layer, the three norms and the MLP.

        Args:
            n_data: Forwarded to the cross-attention layer.
            width: Width of the query stream.
            heads: Number of attention heads.
            mlp_expand_ratio: Hidden-width multiplier of the MLP.
            data_width: Width of the data stream; defaults to ``width``.
            qkv_bias: Whether the attention projections have a bias.
            norm_layer: Norm layer class used for all norms.
            qk_norm: Whether q and k are normalized inside the attention.
        """
        super().__init__()

        data_width = width if data_width is None else data_width

        self.attn = MultiheadCrossAttention(
            n_data=n_data,
            width=width,
            heads=heads,
            data_width=data_width,
            qkv_bias=qkv_bias,
            norm_layer=norm_layer,
            qk_norm=qk_norm
        )
        # ln_1 normalizes the queries, ln_2 the data, ln_3 the MLP input.
        self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
        self.ln_2 = norm_layer(data_width, elementwise_affine=True, eps=1e-6)
        self.ln_3 = norm_layer(width, elementwise_affine=True, eps=1e-6)
        self.mlp = MLP(width=width, expand_ratio=mlp_expand_ratio)

    def forward(self, x: torch.Tensor, data: torch.Tensor):
        """Return ``x`` after the residual cross-attention and MLP updates."""
        attended = self.attn(self.ln_1(x), self.ln_2(data))
        x = x + attended
        fed_forward = self.mlp(self.ln_3(x))
        x = x + fed_forward
        return x
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
class QKVMultiheadAttention(nn.Module):
    """Multi-head self-attention core working on a packed qkv tensor."""

    def __init__(
        self,
        *,
        heads: int,
        n_ctx: int,
        width=None,
        qk_norm=False,
        norm_layer=nn.LayerNorm
    ):
        """Configure the heads and optional per-head q/k normalization.

        Args:
            heads: Number of attention heads.
            n_ctx: Stored on the module; not used by ``forward``.
            width: Model width; only needed when ``qk_norm`` is true, where
                the norm layers operate on ``width // heads`` channels.
            qk_norm: If true, normalize q and k per head before attention.
            norm_layer: Norm layer class used when ``qk_norm`` is true.
        """
        super().__init__()
        self.heads = heads
        self.n_ctx = n_ctx
        if qk_norm:
            self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6)
            self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6)
        else:
            self.q_norm = nn.Identity()
            self.k_norm = nn.Identity()

    def forward(self, qkv):
        """Run self-attention on ``qkv`` (3-D).

        The last dimension holds, per head, the query, key and value channels
        in that order.  The output keeps the batch and sequence sizes.
        """
        bs, n_ctx, width = qkv.shape
        head_channels = width // self.heads // 3
        per_head = qkv.view(bs, n_ctx, self.heads, -1)
        q, k, v = torch.split(per_head, head_channels, dim=-1)

        q = self.q_norm(q)
        k = self.k_norm(k)

        # Move the head axis in front of the sequence axis.
        q, k, v = (
            rearrange(t, 'b n h d -> b h n d', h=self.heads)
            for t in (q, k, v)
        )
        attended = scaled_dot_product_attention(q, k, v)
        return attended.transpose(1, 2).reshape(bs, n_ctx, -1)
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
class MultiheadAttention(nn.Module):
    """Self-attention layer: qkv projection, attention, output projection."""

    def __init__(
        self,
        *,
        n_ctx: int,
        width: int,
        heads: int,
        qkv_bias: bool,
        norm_layer=nn.LayerNorm,
        qk_norm: bool = False,
        drop_path_rate: float = 0.0
    ):
        """Create the projections, the attention core and the drop-path.

        Args:
            n_ctx: Forwarded to the attention core and stored.
            width: Feature width of input and output.
            heads: Number of attention heads.
            qkv_bias: Whether the qkv projection has a bias.
            norm_layer: Norm layer class for optional q/k normalization.
            qk_norm: Whether q and k are normalized per head.
            drop_path_rate: Drop-path probability on the output; a rate of 0
                (or below) installs ``nn.Identity`` instead.
        """
        super().__init__()
        self.n_ctx = n_ctx
        self.width = width
        self.heads = heads
        self.c_qkv = nn.Linear(width, width * 3, bias=qkv_bias)
        self.c_proj = nn.Linear(width, width)
        self.attention = QKVMultiheadAttention(
            heads=heads,
            n_ctx=n_ctx,
            width=width,
            norm_layer=norm_layer,
            qk_norm=qk_norm
        )
        if drop_path_rate > 0.:
            self.drop_path = DropPath(drop_path_rate)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x):
        """Apply self-attention to ``x`` and return the projected result."""
        packed_qkv = self.c_qkv(x)
        attended = self.attention(packed_qkv)
        projected = self.c_proj(attended)
        return self.drop_path(projected)
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
class ResidualAttentionBlock(nn.Module):
    """Pre-norm residual block: self-attention followed by an MLP."""

    def __init__(
        self,
        *,
        n_ctx: int,
        width: int,
        heads: int,
        qkv_bias: bool = True,
        norm_layer=nn.LayerNorm,
        qk_norm: bool = False,
        drop_path_rate: float = 0.0,
    ):
        """Create the attention layer, the MLP and their input norms.

        Args:
            n_ctx: Forwarded to the attention layer.
            width: Feature width.
            heads: Number of attention heads.
            qkv_bias: Whether the qkv projection has a bias.
            norm_layer: Norm layer class used for both norms.
            qk_norm: Whether q and k are normalized inside the attention.
            drop_path_rate: Drop-path probability passed to both the
                attention layer and the MLP.
        """
        super().__init__()
        attention_kwargs = dict(
            n_ctx=n_ctx,
            width=width,
            heads=heads,
            qkv_bias=qkv_bias,
            norm_layer=norm_layer,
            qk_norm=qk_norm,
            drop_path_rate=drop_path_rate,
        )
        self.attn = MultiheadAttention(**attention_kwargs)
        self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
        self.mlp = MLP(width=width, drop_path_rate=drop_path_rate)
        self.ln_2 = norm_layer(width, elementwise_affine=True, eps=1e-6)

    def forward(self, x: torch.Tensor):
        """Return ``x`` after the residual self-attention and MLP updates."""
        attended = self.attn(self.ln_1(x))
        x = x + attended
        fed_forward = self.mlp(self.ln_2(x))
        x = x + fed_forward
        return x
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
class Transformer(nn.Module):
    """A stack of ``layers`` identical residual self-attention blocks."""

    def __init__(
        self,
        *,
        n_ctx: int,
        width: int,
        layers: int,
        heads: int,
        qkv_bias: bool = True,
        norm_layer=nn.LayerNorm,
        qk_norm: bool = False,
        drop_path_rate: float = 0.0
    ):
        """Build the block stack.

        Args:
            n_ctx: Forwarded to every block.
            width: Feature width.
            layers: Number of blocks.
            heads: Number of attention heads per block.
            qkv_bias: Whether the qkv projections have a bias.
            norm_layer: Norm layer class used inside the blocks.
            qk_norm: Whether q and k are normalized inside the attention.
            drop_path_rate: Drop-path probability, the same for every block.
        """
        super().__init__()
        self.n_ctx = n_ctx
        self.width = width
        self.layers = layers

        blocks = []
        for _ in range(layers):
            blocks.append(
                ResidualAttentionBlock(
                    n_ctx=n_ctx,
                    width=width,
                    heads=heads,
                    qkv_bias=qkv_bias,
                    norm_layer=norm_layer,
                    qk_norm=qk_norm,
                    drop_path_rate=drop_path_rate
                )
            )
        self.resblocks = nn.ModuleList(blocks)

    def forward(self, x: torch.Tensor):
        """Pass ``x`` through every block in order and return the result."""
        hidden = x
        for layer in self.resblocks:
            hidden = layer(hidden)
        return hidden
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
class CrossAttentionDecoder(nn.Module):
    """Decode query points against latents with one cross-attention block."""

    def __init__(
        self,
        *,
        num_latents: int,
        out_channels: int,
        fourier_embedder: FourierEmbedder,
        width: int,
        heads: int,
        mlp_expand_ratio: int = 4,
        downsample_ratio: int = 1,
        enable_ln_post: bool = True,
        qkv_bias: bool = True,
        qk_norm: bool = False,
        label_type: str = "binary"
    ):
        """Create the projections and the cross-attention block.

        Args:
            num_latents: Passed to the cross-attention block as ``n_data``.
            out_channels: Width of the final output projection.
            fourier_embedder: Embedder applied to raw queries; its
                ``out_dim`` sets the input width of the query projection.
            width: Model width.
            heads: Number of attention heads.
            mlp_expand_ratio: Hidden-width multiplier of the block's MLP.
            downsample_ratio: When not 1, latents are first projected from
                ``width * downsample_ratio`` to ``width``.
            enable_ln_post: Whether a LayerNorm precedes the output
                projection.  When disabled, ``qk_norm`` is forced off too.
            qkv_bias: Whether the attention projections have a bias.
            qk_norm: Whether q and k are normalized inside the attention.
            label_type: Stored on the module; not used in this class.
        """
        super().__init__()

        self.enable_ln_post = enable_ln_post
        self.fourier_embedder = fourier_embedder
        self.downsample_ratio = downsample_ratio
        self.query_proj = nn.Linear(self.fourier_embedder.out_dim, width)
        if self.downsample_ratio != 1:
            self.latents_proj = nn.Linear(width * downsample_ratio, width)
        if not self.enable_ln_post:
            qk_norm = False
        self.cross_attn_decoder = ResidualCrossAttentionBlock(
            n_data=num_latents,
            width=width,
            mlp_expand_ratio=mlp_expand_ratio,
            heads=heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm
        )

        if self.enable_ln_post:
            self.ln_post = nn.LayerNorm(width)
        self.output_proj = nn.Linear(width, out_channels)
        self.label_type = label_type
        # Running total of decoded query positions (see forward).
        self.count = 0

    def set_cross_attention_processor(self, processor):
        """Install ``processor`` as the attention implementation.

        The attention layer invokes it as ``processor(attn, q, k, v)``.
        """
        self.cross_attn_decoder.attn.attention.attn_processor = processor

    def set_default_cross_attention_processor(self):
        """Restore the default ``CrossAttentionProcessor``.

        An instance (not the class) is installed, because the attention layer
        calls the processor as ``processor(attn, q, k, v)``.
        """
        self.cross_attn_decoder.attn.attention.attn_processor = CrossAttentionProcessor()

    def forward(self, queries=None, query_embeddings=None, latents=None):
        """Decode the queries against ``latents``.

        Args:
            queries: Raw query points; embedded and projected when
                ``query_embeddings`` is not given.
            query_embeddings: Pre-computed query embeddings; when given,
                ``queries`` is ignored.
            latents: Latent tensor attended to by the queries.

        Returns:
            Output of the final projection, ``out_channels`` wide.
        """
        if query_embeddings is None:
            query_embeddings = self.query_proj(self.fourier_embedder(queries).to(latents.dtype))
        self.count += query_embeddings.shape[1]
        if self.downsample_ratio != 1:
            latents = self.latents_proj(latents)
        x = self.cross_attn_decoder(query_embeddings, latents)
        if self.enable_ln_post:
            x = self.ln_post(x)
        occ = self.output_proj(x)
        return occ
|
| 477 |
-
|
| 478 |
-
|
| 479 |
-
def fps(
    src: torch.Tensor,
    batch: Optional[Tensor] = None,
    ratio: Optional[Union[Tensor, float]] = None,
    random_start: bool = True,
    batch_size: Optional[int] = None,
    ptr: Optional[Union[Tensor, List[int]]] = None,
):
    """Run ``torch_cluster.fps`` on ``src`` after casting it to float32.

    All remaining arguments are forwarded positionally and unchanged, and the
    result of ``torch_cluster.fps`` is returned as-is. ``torch_cluster`` is
    imported lazily so the module can be imported without it installed.
    """
    points = src.float()
    from torch_cluster import fps as _cluster_fps
    return _cluster_fps(points, batch, ratio, random_start, batch_size, ptr)
|
| 491 |
-
|
| 492 |
-
|
| 493 |
-
class PointCrossAttentionEncoder(nn.Module):
    """Encode a surface point cloud into latent tokens.

    A subset of the input points is chosen as queries (farthest point
    sampling, or voxelisation when ``voxel_query`` is set), both queries and
    inputs are Fourier-embedded, and the queries cross-attend to the inputs.
    An optional self-attention ``Transformer`` and LayerNorm follow.
    """

    def __init__(
        self, *,
        num_latents: int,
        downsample_ratio: float,
        pc_size: int,
        pc_sharpedge_size: int,
        fourier_embedder: FourierEmbedder,
        point_feats: int,
        width: int,
        heads: int,
        layers: int,
        voxel_query_res: int,
        normal_pe: bool = False,
        qkv_bias: bool = True,
        use_ln_post: bool = False,
        use_checkpoint: bool = False,
        qk_norm: bool = False,
        jitter_query: bool = False,
        voxel_query: bool = False,
    ):
        """Build the input projections and attention stacks.

        Args:
            num_latents: Number of latent tokens / query points.
            downsample_ratio: Ratio between sampled input points and queries.
            pc_size: Number of leading points in ``pc`` that are random
                surface samples.
            pc_sharpedge_size: Number of trailing points in ``pc`` that are
                sharp-edge samples (may be 0).
            fourier_embedder: Positional embedder; its ``out_dim`` sizes the
                input projections.
            point_feats: Number of per-point feature channels (0 for none).
            width: Hidden width.
            heads: Number of attention heads.
            layers: Number of self-attention layers; 0 disables them.
            voxel_query_res: Resolution passed to ``voxelize_from_point`` and
                used to scale query jitter.
            normal_pe: Fourier-embed the first three feature channels.
            qkv_bias: Forwarded to the attention blocks.
            use_ln_post: Apply a final LayerNorm to the latents.
            use_checkpoint: Stored on the module; not used inside this class.
            qk_norm: Forwarded to the attention blocks.
            jitter_query: Add uniform noise to the query positions.
            voxel_query: Pick queries with ``voxelize_from_point`` instead of
                farthest point sampling.
        """
        super().__init__()

        self.use_checkpoint = use_checkpoint
        self.num_latents = num_latents
        self.downsample_ratio = downsample_ratio
        self.point_feats = point_feats
        self.normal_pe = normal_pe
        self.jitter_query = jitter_query
        self.voxel_query = voxel_query
        self.voxel_query_res = voxel_query_res

        if pc_sharpedge_size == 0:
            print(
                'PointCrossAttentionEncoder INFO: pc_sharpedge_size is zero')
        else:
            print(
                f'PointCrossAttentionEncoder INFO: pc_sharpedge_size is given, using pc_size={pc_size}, pc_sharpedge_size={pc_sharpedge_size}')

        self.pc_size = pc_size
        self.pc_sharpedge_size = pc_sharpedge_size

        self.fourier_embedder = fourier_embedder

        if self.jitter_query or self.voxel_query:
            # Queries carry positions only, so they get their own projection.
            self.input_proj_q = nn.Linear(self.fourier_embedder.out_dim, width)
            self.input_proj_kv = nn.Linear(self.fourier_embedder.out_dim + point_feats, width)
        else:
            # Queries and inputs share one layout and one projection.
            self.input_proj = nn.Linear(self.fourier_embedder.out_dim + point_feats, width)
        self.cross_attn = ResidualCrossAttentionBlock(
            width=width,
            heads=heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm
        )

        self.self_attn = None
        if layers > 0:
            self.self_attn = Transformer(
                n_ctx=num_latents,
                width=width,
                layers=layers,
                heads=heads,
                qkv_bias=qkv_bias,
                qk_norm=qk_norm
            )

        if use_ln_post:
            self.ln_post = nn.LayerNorm(width)
        else:
            self.ln_post = None

    def sample_points_and_latents(self, pc: torch.FloatTensor, feats: Optional[torch.FloatTensor] = None):
        """Sample input and query points and build their embeddings.

        Args:
            pc: Point positions; the first ``pc_size`` points along dim 1 are
                random surface samples, the next ``pc_sharpedge_size`` are
                sharp-edge samples.
            feats: Per-point features split the same way as ``pc``; required
                when ``point_feats != 0``.

        Returns:
            ``(query, data, pc_infos)``: embedded queries, embedded inputs,
            and a list of the sampled point sets. In voxel mode ``pc_infos``
            is ``[query_voxel_indices, query_random_pc]``; otherwise it holds
            the query/input point sets (combined, random, sharp-edge).
        """
        B, _, D = pc.shape
        num_pts = self.num_latents * self.downsample_ratio

        # Compute number of latents
        num_latents = int(num_pts / self.downsample_ratio)

        # Compute the number of random and sharpedge latents
        num_random_query = self.pc_size / (self.pc_size + self.pc_sharpedge_size) * num_latents
        num_sharpedge_query = num_latents - num_random_query

        # Split random and sharpedge surface points
        random_pc, sharpedge_pc = torch.split(pc, [self.pc_size, self.pc_sharpedge_size], dim=1)
        assert random_pc.shape[1] <= self.pc_size, "Random surface points size must be less than or equal to pc_size"
        assert sharpedge_pc.shape[1] <= self.pc_sharpedge_size, "Sharpedge surface points size must be less than or equal to pc_sharpedge_size"

        # Queries carry features only in the plain (FPS, un-jittered) mode.
        plain_query = not self.voxel_query and not self.jitter_query

        # Randomly select random surface points and random query points
        input_random_pc_size = int(num_random_query * self.downsample_ratio)
        random_query_ratio = num_random_query / input_random_pc_size
        idx_random_pc = torch.randperm(random_pc.shape[1], device=random_pc.device)[:input_random_pc_size]
        input_random_pc = random_pc[:, idx_random_pc, :]

        if self.voxel_query:
            query_random_pc, query_voxel_indices = voxelize_from_point(pc, num_latents, resolution=self.voxel_query_res)
        else:
            # fps() works on a flat point list plus a per-point batch index.
            flatten_input_random_pc = input_random_pc.view(B * input_random_pc_size, D)
            N_down = int(flatten_input_random_pc.shape[0] / B)
            batch_down = torch.arange(B).to(pc.device)
            batch_down = torch.repeat_interleave(batch_down, N_down)
            idx_query_random = fps(flatten_input_random_pc, batch_down, ratio=random_query_ratio)
            query_random_pc = flatten_input_random_pc[idx_query_random].view(B, -1, D)

        # Randomly select sharpedge surface points and sharpedge query points
        input_sharpedge_pc_size = int(num_sharpedge_query * self.downsample_ratio)
        # Sharp-edge points are dropped in voxel mode as well as when none
        # were requested; the feature branch below must use the same test.
        no_sharpedge = input_sharpedge_pc_size == 0 or self.voxel_query
        if no_sharpedge:
            input_sharpedge_pc = torch.zeros(B, 0, D, dtype=input_random_pc.dtype).to(pc.device)
            query_sharpedge_pc = torch.zeros(B, 0, D, dtype=query_random_pc.dtype).to(pc.device)
        else:
            sharpedge_query_ratio = num_sharpedge_query / input_sharpedge_pc_size
            idx_sharpedge_pc = torch.randperm(sharpedge_pc.shape[1], device=sharpedge_pc.device)[
                               :input_sharpedge_pc_size]
            input_sharpedge_pc = sharpedge_pc[:, idx_sharpedge_pc, :]
            flatten_input_sharpedge_surface_points = input_sharpedge_pc.view(B * input_sharpedge_pc_size, D)
            N_down = int(flatten_input_sharpedge_surface_points.shape[0] / B)
            batch_down = torch.arange(B).to(pc.device)
            batch_down = torch.repeat_interleave(batch_down, N_down)
            idx_query_sharpedge = fps(flatten_input_sharpedge_surface_points, batch_down, ratio=sharpedge_query_ratio)
            query_sharpedge_pc = flatten_input_sharpedge_surface_points[idx_query_sharpedge].view(B, -1, D)

        # Concatenate random and sharpedge surface points and query points
        query_pc = torch.cat([query_random_pc, query_sharpedge_pc], dim=1)
        input_pc = torch.cat([input_random_pc, input_sharpedge_pc], dim=1)

        if self.jitter_query:
            # Uniform noise in [-0.5, 0.5) scaled by half the voxel resolution.
            R = self.voxel_query_res // 2
            noise = torch.rand_like(query_pc)
            query_pc += (noise - 0.5) / R

        # PE
        query = self.fourier_embedder(query_pc)
        data = self.fourier_embedder(input_pc)

        # Concat normal if given
        if self.point_feats != 0:

            random_surface_feats, sharpedge_surface_feats = torch.split(feats, [self.pc_size, self.pc_sharpedge_size],
                                                                        dim=1)
            input_random_surface_feats = random_surface_feats[:, idx_random_pc, :]
            if plain_query:
                flatten_input_random_surface_feats = input_random_surface_feats.view(B * input_random_pc_size, -1)
                query_random_feats = flatten_input_random_surface_feats[idx_query_random].view(
                    B, -1, flatten_input_random_surface_feats.shape[-1])

            # Previously tested only the size, so voxel mode with a non-zero
            # sharp-edge size read the never-assigned idx_sharpedge_pc.
            if no_sharpedge:
                input_sharpedge_surface_feats = torch.zeros(B, 0, self.point_feats,
                                                            dtype=input_random_surface_feats.dtype).to(pc.device)
                if plain_query:
                    query_sharpedge_feats = torch.zeros(B, 0, self.point_feats, dtype=query_random_feats.dtype).to(
                        pc.device)
            else:
                input_sharpedge_surface_feats = sharpedge_surface_feats[:, idx_sharpedge_pc, :]
                if plain_query:
                    flatten_input_sharpedge_surface_feats = input_sharpedge_surface_feats.view(
                        B * input_sharpedge_pc_size, -1)
                    query_sharpedge_feats = flatten_input_sharpedge_surface_feats[idx_query_sharpedge].view(
                        B, -1, flatten_input_sharpedge_surface_feats.shape[-1])
            if plain_query:
                query_feats = torch.cat([query_random_feats, query_sharpedge_feats], dim=1)
            input_feats = torch.cat([input_random_surface_feats, input_sharpedge_surface_feats], dim=1)

            if self.normal_pe:
                # Replace the first three feature channels by their embedding.
                if plain_query:
                    query_normal_pe = self.fourier_embedder(query_feats[..., :3])
                    query_feats = torch.cat([query_normal_pe, query_feats[..., 3:]], dim=-1)
                input_normal_pe = self.fourier_embedder(input_feats[..., :3])
                input_feats = torch.cat([input_normal_pe, input_feats[..., 3:]], dim=-1)

            if plain_query:
                query = torch.cat([query, query_feats], dim=-1)
            data = torch.cat([data, input_feats], dim=-1)

        if input_sharpedge_pc_size == 0:
            # Single-point placeholders so pc_infos entries are never empty.
            query_sharpedge_pc = torch.zeros(B, 1, D).to(pc.device)
            input_sharpedge_pc = torch.zeros(B, 1, D).to(pc.device)

        if self.voxel_query:
            pc_infos = [query_voxel_indices, query_random_pc]
        else:
            pc_infos = [query_pc, input_pc, query_random_pc, input_random_pc, query_sharpedge_pc, input_sharpedge_pc]
        return query.view(B, -1, query.shape[-1]), data.view(B, -1, data.shape[-1]), pc_infos

    def forward(self, pc, feats):
        """Encode a point cloud into latents.

        Args:
            pc (torch.FloatTensor): [B, N, 3]
            feats (torch.FloatTensor or None): [B, N, C]

        Returns:
            ``(latents, pc_infos)``: the encoded latent tokens and the
            sampled point sets from ``sample_points_and_latents``.
        """
        query, data, pc_infos = self.sample_points_and_latents(pc, feats)

        if self.jitter_query or self.voxel_query:
            query = self.input_proj_q(query)
            data = self.input_proj_kv(data)
        else:
            query = self.input_proj(query)
            data = self.input_proj(data)

        latents = self.cross_attn(query, data)
        if self.self_attn is not None:
            latents = self.self_attn(latents)

        if self.ln_post is not None:
            latents = self.ln_post(latents)

        return latents, pc_infos
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ultrashape/models/autoencoders/attention_processors.py
DELETED
|
@@ -1,103 +0,0 @@
|
|
| 1 |
-
# ==============================================================================
|
| 2 |
-
# Original work Copyright (c) 2025 Tencent.
|
| 3 |
-
# Modified work Copyright (c) 2025 UltraShape Team.
|
| 4 |
-
#
|
| 5 |
-
# Modified by UltraShape on 2025.12.25
|
| 6 |
-
# ==============================================================================
|
| 7 |
-
|
| 8 |
-
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
| 9 |
-
# except for the third-party components listed below.
|
| 10 |
-
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
| 11 |
-
# in the repsective licenses of these third-party components.
|
| 12 |
-
# Users must comply with all terms and conditions of original licenses of these third-party
|
| 13 |
-
# components and must ensure that the usage of the third party components adheres to
|
| 14 |
-
# all relevant laws and regulations.
|
| 15 |
-
|
| 16 |
-
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
| 17 |
-
# their software and algorithms, including trained model weights, parameters (including
|
| 18 |
-
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
| 19 |
-
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
| 20 |
-
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
| 21 |
-
|
| 22 |
-
import os
|
| 23 |
-
|
| 24 |
-
import torch
|
| 25 |
-
import torch.nn.functional as F
|
| 26 |
-
|
| 27 |
-
# Attention kernel used by the processors in this module. Defaults to
# PyTorch's implementation; setting CA_USE_SAGEATTN=1 swaps in SageAttention.
scaled_dot_product_attention = F.scaled_dot_product_attention
if os.environ.get('CA_USE_SAGEATTN', '0') == '1':
    try:
        from sageattention import sageattn
    except ImportError as err:
        # Name the real switch and keep the original failure as the cause.
        raise ImportError(
            'Please install the package "sageattention" to use CA_USE_SAGEATTN=1.'
        ) from err
    scaled_dot_product_attention = sageattn
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
class CrossAttentionProcessor:
    """Default processor: plain attention over all keys and values."""

    def __call__(self, attn, q, k, v):
        """Return attention of ``q`` over ``k``/``v``; ``attn`` is unused."""
        return scaled_dot_product_attention(q, k, v)
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
class FlashVDMCrossAttentionProcessor:
    """Cross-attention processor that can restrict attention to top-k keys.

    The mode is read from ``self.topk`` on each call and reset to ``False``
    afterwards, so the caller sets it again before every call that needs
    key selection:

    * ``True``: select one key subset for the whole query set.
    * ``False`` or ``None``: attend to every key.
    * ``(idx, counts)``: split the queries into consecutive chunks of the
      given sizes and select a key subset per chunk via ``select_topkv``.
    """

    def __init__(self, topk=None):
        self.topk = topk

    def __call__(self, attn, q, k, v):
        """Run attention for ``q`` over (a subset of) ``k``/``v``.

        Args:
            attn: The calling attention module; unused.
            q, k, v: Query, key and value tensors, indexed here with four
                axes, with the token axis at ``-2``.

        Returns:
            The attention output for every query position.
        """
        num_keys = k.shape[-2]
        # Number of keys to keep: fixed budgets for the two known latent
        # sizes, a third of the keys otherwise.
        if num_keys == 3072:
            topk = 1024
        elif num_keys == 512:
            topk = 256
        else:
            topk = num_keys // 3

        if self.topk is True:
            # Deliberately not routed through select_topkv so subclasses
            # overriding it only affect the chunked mode.
            k0, v0 = self._mean_topkv(q, k, v, topk, stride=100)
            out = scaled_dot_product_attention(q, k0, v0)
        elif self.topk is False or self.topk is None:
            # None is the constructor default; treat it as "no selection"
            # instead of failing while unpacking it as (idx, counts).
            out = scaled_dot_product_attention(q, k, v)
        else:
            idx, counts = self.topk
            start = 0
            outs = []
            for _, count in zip(idx, counts):
                end = start + count
                q_chunk = q[:, :, start:end, :]
                k0, v0 = self.select_topkv(q_chunk, k, v, topk)
                outs.append(scaled_dot_product_attention(q_chunk, k0, v0))
                start = end
            out = torch.cat(outs, dim=-2)
        self.topk = False
        return out

    @staticmethod
    def _mean_topkv(q, k, v, topk, stride):
        """Keep the ``topk`` keys with the highest mean similarity.

        Similarity is the dot product between every ``stride``-th query and
        all keys, averaged over the sampled queries.
        """
        sampled_q = q[:, :, ::stride, :]
        sim = torch.mean(sampled_q @ k.transpose(-1, -2), -2)
        # No squeeze(-2) here: it is a no-op for several heads and collapses
        # the head axis for a single head, which breaks the expand below.
        topk_ind = torch.topk(sim, dim=-1, k=topk).indices.unsqueeze(-1)
        topk_ind = topk_ind.expand(-1, -1, -1, v.shape[-1])
        v0 = torch.gather(v, dim=-2, index=topk_ind)
        k0 = torch.gather(k, dim=-2, index=topk_ind)
        return k0, v0

    def select_topkv(self, q_chunk, k, v, topk):
        """Select keys/values for one query chunk; returns ``(k0, v0)``."""
        return self._mean_topkv(q_chunk, k, v, topk, stride=50)
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
class FlashVDMTopMCrossAttentionProcessor(FlashVDMCrossAttentionProcessor):
    """Variant that keeps every key whose attention weight is non-negligible."""

    def select_topkv(self, q_chunk, k, v, topk):
        """Select the keys activated by a subsample of ``q_chunk``.

        Every 30th query is scored against all keys, the scores are
        soft-maxed over keys and averaged over dim 1, and each key index
        with a weight above 1e-6 anywhere is kept. ``topk`` is ignored.

        Returns:
            ``(k0, v0)`` gathered along the key axis.
        """
        sampled_q = q_chunk[:, :, ::30, :]
        weights = (sampled_q @ k.transpose(-1, -2)).softmax(-1)
        weights = torch.mean(weights, 1)
        # Third index tensor of torch.where is the key position.
        active_keys = torch.where(weights > 1e-6)[2]
        kept = torch.unique(active_keys, return_counts=True)[0]
        gather_index = kept.unsqueeze(0).unsqueeze(0).unsqueeze(-1)
        gather_index = gather_index.expand(-1, v.shape[1], -1, v.shape[-1])
        v0 = torch.gather(v, dim=-2, index=gather_index)
        k0 = torch.gather(k, dim=-2, index=gather_index)
        return k0, v0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ultrashape/models/autoencoders/model.py
DELETED
|
@@ -1,377 +0,0 @@
|
|
| 1 |
-
# ==============================================================================
|
| 2 |
-
# Original work Copyright (c) 2025 Tencent.
|
| 3 |
-
# Modified work Copyright (c) 2025 UltraShape Team.
|
| 4 |
-
#
|
| 5 |
-
# Modified by UltraShape on 2025.12.25
|
| 6 |
-
# ==============================================================================
|
| 7 |
-
|
| 8 |
-
# Open Source Model Licensed under the Apache License Version 2.0
|
| 9 |
-
# and Other Licenses of the Third-Party Components therein:
|
| 10 |
-
# The below Model in this distribution may have been modified by THL A29 Limited
|
| 11 |
-
# ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
|
| 12 |
-
|
| 13 |
-
# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
|
| 14 |
-
# The below software and/or models in this distribution may have been
|
| 15 |
-
# modified by THL A29 Limited ("Tencent Modifications").
|
| 16 |
-
# All Tencent Modifications are Copyright (C) THL A29 Limited.
|
| 17 |
-
|
| 18 |
-
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
| 19 |
-
# except for the third-party components listed below.
|
| 20 |
-
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
| 21 |
-
# in the repsective licenses of these third-party components.
|
| 22 |
-
# Users must comply with all terms and conditions of original licenses of these third-party
|
| 23 |
-
# components and must ensure that the usage of the third party components adheres to
|
| 24 |
-
# all relevant laws and regulations.
|
| 25 |
-
|
| 26 |
-
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
| 27 |
-
# their software and algorithms, including trained model weights, parameters (including
|
| 28 |
-
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
| 29 |
-
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
| 30 |
-
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
| 31 |
-
|
| 32 |
-
import os
|
| 33 |
-
from typing import Union, List
|
| 34 |
-
|
| 35 |
-
import numpy as np
|
| 36 |
-
import torch
|
| 37 |
-
import torch.nn as nn
|
| 38 |
-
import yaml
|
| 39 |
-
|
| 40 |
-
from .attention_blocks import FourierEmbedder, Transformer, CrossAttentionDecoder, PointCrossAttentionEncoder
|
| 41 |
-
from .surface_extractors import MCSurfaceExtractor, SurfaceExtractors
|
| 42 |
-
from .volume_decoders import VanillaVolumeDecoder, FlashVDMVolumeDecoding, HierarchicalVolumeDecoding
|
| 43 |
-
from ...utils import logger, synchronize_timer, smart_load_model
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
class DiagonalGaussianDistribution:
    """Gaussian with diagonal covariance, parameterised by mean and log-variance."""

    def __init__(self, parameters: Union[torch.Tensor, List[torch.Tensor]], deterministic=False, feat_dim=1):
        """
        Initialize a diagonal Gaussian distribution with mean and log-variance parameters.

        Args:
            parameters (Union[torch.Tensor, List[torch.Tensor]]):
                Either a single tensor containing concatenated mean and log-variance along `feat_dim`,
                or a list of two tensors [mean, logvar].
            deterministic (bool, optional): If True, the distribution is deterministic (zero variance).
                Default is False.
            feat_dim (int, optional): Dimension along which mean and logvar are
                concatenated if parameters is a single tensor. Default is 1.
        """
        self.feat_dim = feat_dim
        self.parameters = parameters

        if isinstance(parameters, list):
            self.mean = parameters[0]
            self.logvar = parameters[1]
        else:
            self.mean, self.logvar = torch.chunk(parameters, 2, dim=feat_dim)

        # Clamp so exp() below can neither overflow nor collapse to zero.
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean)

    def sample(self):
        """
        Sample from the diagonal Gaussian distribution.

        Returns:
            torch.Tensor: A sample tensor with the same shape as the mean.
        """
        return self.mean + self.std * torch.randn_like(self.mean)

    def kl(self, other=None, dims=(1, 2)):
        """
        Compute the Kullback-Leibler (KL) divergence between this distribution and another.

        If `other` is None, compute KL divergence to a standard normal distribution N(0, I).

        Args:
            other (DiagonalGaussianDistribution, optional): Another diagonal Gaussian distribution.
            dims (tuple, optional): Dimensions along which to compute the mean KL divergence.
                Default is (1, 2).

        Returns:
            torch.Tensor: The mean KL divergence value; a one-element zero tensor
            on the mean's device and dtype when the distribution is deterministic.
        """
        if self.deterministic:
            # Match the mean's device/dtype so the result combines with
            # other loss terms without a device mismatch.
            return self.mean.new_zeros(1)
        if other is None:
            return 0.5 * torch.mean(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=dims)
        return 0.5 * torch.mean(
            torch.pow(self.mean - other.mean, 2) / other.var
            + self.var / other.var - 1.0 - self.logvar + other.logvar,
            dim=dims)

    def nll(self, sample, dims=(1, 2, 3)):
        """
        Compute the negative log-likelihood of `sample` under this distribution.

        Args:
            sample (torch.Tensor): Values to score; must broadcast against the mean.
            dims (tuple, optional): Dimensions that are summed over. Default is (1, 2, 3).

        Returns:
            torch.Tensor: The summed negative log-likelihood; a one-element zero
            tensor on the mean's device and dtype when the distribution is deterministic.
        """
        if self.deterministic:
            return self.mean.new_zeros(1)
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims)

    def mode(self):
        """Return the mode of the distribution, i.e. its mean."""
        return self.mean
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
class VectsetVAE(nn.Module):
    """Base VAE holding a volume decoder and a surface extractor.

    Subclasses are expected to provide ``self.geo_decoder``, which
    ``latents2mesh`` passes to the volume decoder.
    """

    @classmethod
    @synchronize_timer('VectsetVAE Model Loading')
    def from_single_file(
        cls,
        ckpt_path,
        config_path=None,
        params=None,
        device='cuda',
        dtype=torch.float16,
        use_safetensors=None,
        **kwargs,
    ):
        """Build a model and load its weights from one checkpoint file.

        Args:
            ckpt_path: Checkpoint path. With ``use_safetensors`` the
                ``.ckpt`` suffix is replaced by ``.safetensors``.
            config_path: YAML file whose ``params`` entry holds the
                constructor arguments; only read when ``params`` is None.
            params: Constructor arguments; takes precedence over the config.
            device: Device the model is moved to.
            dtype: Dtype the model is cast to.
            use_safetensors: Load with ``safetensors`` instead of ``torch.load``.
            **kwargs: Extra constructor arguments overriding ``params``/config.

        Returns:
            The constructed model with the checkpoint loaded.

        Raises:
            ValueError: If neither ``params`` nor ``config_path`` is given.
            FileNotFoundError: If the checkpoint file does not exist.
        """
        if params is not None:
            # Copy so the caller's dict is not mutated by update() below.
            model_kwargs = dict(params)
        else:
            if config_path is None:
                raise ValueError("Either `params` or `config_path` must be provided")
            with open(config_path, 'r', encoding='utf-8') as f:
                config = yaml.safe_load(f)
            model_kwargs = config['params']
        model_kwargs.update(kwargs)

        # load ckpt
        if use_safetensors:
            ckpt_path = ckpt_path.replace('.ckpt', '.safetensors')
        if not os.path.exists(ckpt_path):
            raise FileNotFoundError(f"Model file {ckpt_path} not found")

        logger.info(f"Loading model from {ckpt_path}")
        if use_safetensors:
            import safetensors.torch
            ckpt = safetensors.torch.load_file(ckpt_path, device='cpu')
        else:
            ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=True)

        model = cls(**model_kwargs)
        model.load_state_dict(ckpt)

        model.to(device=device, dtype=dtype)
        return model

    @classmethod
    def from_pretrained(
        cls,
        model_path,
        device='cuda',
        params=None,
        dtype=torch.float16,
        use_safetensors=False,
        variant='fp16',
        subfolder='hunyuan3d-vae-v2-1',
        **kwargs,
    ):
        """Resolve config and checkpoint paths, then load the model.

        ``smart_load_model`` maps ``model_path``/``subfolder``/``variant`` to
        a ``(config_path, ckpt_path)`` pair, which is handed to
        ``from_single_file`` together with the remaining arguments.
        """
        config_path, ckpt_path = smart_load_model(
            model_path,
            subfolder=subfolder,
            use_safetensors=use_safetensors,
            variant=variant
        )

        return cls.from_single_file(
            ckpt_path,
            config_path=config_path,
            params=params,
            device=device,
            dtype=dtype,
            use_safetensors=use_safetensors,
            **kwargs
        )

    def init_from_ckpt(self, path, ignore_keys=()):
        """Load weights non-strictly, dropping keys with ignored prefixes.

        Args:
            path: Checkpoint path; a top-level ``"state_dict"`` entry is
                unwrapped when present.
            ignore_keys: Key prefixes to remove before loading.
        """
        # NOTE(security): torch.load without weights_only unpickles arbitrary
        # objects; only use this with trusted checkpoint files.
        state_dict = torch.load(path, map_location="cpu")
        state_dict = state_dict.get("state_dict", state_dict)
        prefixes = tuple(ignore_keys)
        for k in list(state_dict.keys()):
            # One test per key: a key matching several prefixes must be
            # deleted only once.
            if k.startswith(prefixes):
                print("Deleting key {} from state_dict.".format(k))
                del state_dict[k]
        missing, unexpected = self.load_state_dict(state_dict, strict=False)
        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
        if len(unexpected) > 0:
            print(f"Unexpected Keys: {unexpected}")

    def __init__(
        self,
        volume_decoder=None,
        surface_extractor=None
    ):
        """Store the decoders, defaulting to the vanilla implementations.

        Args:
            volume_decoder: Callable turning latents into grid logits;
                defaults to ``VanillaVolumeDecoder()``.
            surface_extractor: Callable turning grid logits into surfaces;
                defaults to ``MCSurfaceExtractor()``.
        """
        super().__init__()
        if volume_decoder is None:
            volume_decoder = VanillaVolumeDecoder()
        if surface_extractor is None:
            surface_extractor = MCSurfaceExtractor()
        self.volume_decoder = volume_decoder
        self.surface_extractor = surface_extractor

    def latents2mesh(self, latents: torch.FloatTensor, **kwargs):
        """Decode latents to grid logits and extract surfaces from them.

        ``kwargs`` are forwarded to both the volume decoder and the surface
        extractor.

        Returns:
            ``(outputs, grid_logits)``: the surface extractor output and the
            grid logits it was computed from.
        """
        with synchronize_timer('Volume decoding'):
            grid_logits = self.volume_decoder(latents, self.geo_decoder, **kwargs)
        with synchronize_timer('Surface extraction'):
            outputs = self.surface_extractor(grid_logits, **kwargs)
        return outputs, grid_logits

    def enable_flashvdm_decoder(
        self,
        enabled: bool = True,
        adaptive_kv_selection=True,
        topk_mode='mean',
        mc_algo='mc',
    ):
        """Switch between the FlashVDM and the vanilla decoding pipeline.

        Args:
            enabled: Use a FlashVDM decoder; otherwise restore the vanilla
                volume decoder and ``MCSurfaceExtractor``.
            adaptive_kv_selection: Choose ``FlashVDMVolumeDecoding`` over
                ``HierarchicalVolumeDecoding``.
            topk_mode: Passed to ``FlashVDMVolumeDecoding``.
            mc_algo: Key into ``SurfaceExtractors``.

        Raises:
            ValueError: If ``enabled`` and ``mc_algo`` is not registered.
        """
        if not enabled:
            self.volume_decoder = VanillaVolumeDecoder()
            self.surface_extractor = MCSurfaceExtractor()
            return

        # Validate before touching state so a bad mc_algo leaves the model
        # with its previous decoder/extractor pair.
        if mc_algo not in SurfaceExtractors.keys():
            raise ValueError(f'Unsupported mc_algo {mc_algo}, available:{list(SurfaceExtractors.keys())}')
        if adaptive_kv_selection:
            self.volume_decoder = FlashVDMVolumeDecoding(topk_mode)
        else:
            self.volume_decoder = HierarchicalVolumeDecoding()
        self.surface_extractor = SurfaceExtractors[mc_algo]()
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
class ShapeVAE(VectsetVAE):
    """Shape VAE: point-cloud encoder, KL bottleneck, latent transformer and geometry decoder.

    Data flow as wired in this class:
      * ``encode``: ``self.encoder`` -> ``self.pre_kl`` -> ``DiagonalGaussianDistribution``
        -> sampled (or mode) latents.
      * ``decode`` / ``forward``: ``self.post_kl`` -> ``self.transformer``.
      * ``query``: ``self.geo_decoder`` evaluated at query points.
    """

    def __init__(
        self,
        *,
        num_latents: int,
        embed_dim: int,
        width: int,
        heads: int,
        num_decoder_layers: int,
        num_encoder_layers: int = 8,
        pc_size: int = 5120,
        pc_sharpedge_size: int = 5120,
        point_feats: int = 3,
        downsample_ratio: int = 20,
        geo_decoder_downsample_ratio: int = 1,
        geo_decoder_mlp_expand_ratio: int = 4,
        geo_decoder_ln_post: bool = True,
        num_freqs: int = 8,
        include_pi: bool = True,
        qkv_bias: bool = True,
        qk_norm: bool = False,
        label_type: str = "binary",
        drop_path_rate: float = 0.0,
        scale_factor: float = 1.0,
        use_ln_post: bool = True,
        enable_flashvdm: bool = False,
        ckpt_path = None,
        jitter_query: bool = False,
        voxel_query: bool = False,
        voxel_query_res: int = 128,
    ):
        """Build all sub-modules; optionally restore a checkpoint and enable FlashVDM.

        All arguments are keyword-only. Most are forwarded verbatim to the
        sub-module constructors below; the ones handled directly here are:

        Args:
            embed_dim: Latent channel count. ``pre_kl`` projects ``width`` to
                ``2 * embed_dim`` and ``post_kl`` projects ``embed_dim`` back to ``width``.
            geo_decoder_downsample_ratio: Divides both ``width`` and ``heads`` for
                the geometry decoder (integer division).
            scale_factor: Stored on ``self.scale_factor``; not otherwise used in this class.
            ckpt_path: When not ``None``, ``self.init_from_ckpt(ckpt_path)`` is called
                after all sub-modules exist.
            enable_flashvdm: When truthy, ``self.enable_flashvdm_decoder()`` is called
                with its defaults at the end of construction.
        """
        # The parent constructor is called without arguments, so it installs its
        # default volume decoder and surface extractor.
        super().__init__()
        self.geo_decoder_ln_post = geo_decoder_ln_post
        self.downsample_ratio = downsample_ratio

        # One Fourier embedder instance is shared by the encoder and the geometry decoder.
        self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)

        self.encoder = PointCrossAttentionEncoder(
            fourier_embedder=self.fourier_embedder,
            num_latents=num_latents,
            downsample_ratio=self.downsample_ratio,
            pc_size=pc_size,
            pc_sharpedge_size=pc_sharpedge_size,
            point_feats=point_feats,
            width=width,
            heads=heads,
            layers=num_encoder_layers,
            qkv_bias=qkv_bias,
            use_ln_post=use_ln_post,
            qk_norm=qk_norm,
            jitter_query=jitter_query,
            voxel_query=voxel_query,
            voxel_query_res=voxel_query_res
        )

        # ``pre_kl`` outputs 2 * embed_dim channels, consumed as the "moments" of the
        # diagonal Gaussian posterior in ``encode``.
        self.pre_kl = nn.Linear(width, embed_dim * 2)
        self.post_kl = nn.Linear(embed_dim, width)

        self.transformer = Transformer(
            n_ctx=num_latents,
            width=width,
            layers=num_decoder_layers,
            heads=heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            drop_path_rate=drop_path_rate
        )

        self.geo_decoder = CrossAttentionDecoder(
            fourier_embedder=self.fourier_embedder,
            out_channels=1,
            num_latents=num_latents,
            mlp_expand_ratio=geo_decoder_mlp_expand_ratio,
            downsample_ratio=geo_decoder_downsample_ratio,
            enable_ln_post=self.geo_decoder_ln_post,
            width=width // geo_decoder_downsample_ratio,
            heads=heads // geo_decoder_downsample_ratio,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            label_type=label_type,
        )

        self.scale_factor = scale_factor
        self.latent_shape = (num_latents, embed_dim)

        # Restore weights only after every sub-module has been created.
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path)

        if enable_flashvdm:
            self.enable_flashvdm_decoder()

    def forward(self, latents):
        """Apply ``post_kl`` then ``transformer`` to ``latents`` (same computation as ``decode``)."""
        latents = self.post_kl(latents)
        latents = self.transformer(latents)
        return latents

    def encode(self, surface, sample_posterior=True, need_kl=False, need_voxel=False):
        """Encode a surface point set into latents.

        Args:
            surface: Indexed as ``surface[:, :, :3]`` (passed to the encoder as point
                positions) and ``surface[:, :, 3:]`` (passed as per-point features).
            sample_posterior: When truthy draw ``posterior.sample()``; otherwise use
                ``posterior.mode()``.
            need_kl: When truthy also return the posterior object. Takes precedence
                over ``need_voxel``.
            need_voxel: When truthy (and ``need_kl`` is falsy) also return
                ``pc_infos[0]`` from the encoder's second output.
                NOTE(review): presumably voxel/query information - confirm against the encoder.

        Returns:
            ``(latents, posterior)`` if ``need_kl``; else ``(latents, pc_infos[0])`` if
            ``need_voxel``; else ``latents`` alone.
        """
        pc, feats = surface[:, :, :3], surface[:, :, 3:]
        latents, pc_infos = self.encoder(pc, feats)
        moments = self.pre_kl(latents)
        posterior = DiagonalGaussianDistribution(moments, feat_dim=-1)
        if sample_posterior:
            latents = posterior.sample()
        else:
            latents = posterior.mode()
        if need_kl:
            return latents, posterior
        if need_voxel:
            return latents, pc_infos[0]
        return latents

    def decode(self, latents, voxel_idx=None):
        """Project latents with ``post_kl`` and run them through ``transformer``.

        ``voxel_idx`` is accepted but not used by this implementation.
        """
        latents = self.post_kl(latents)
        latents = self.transformer(latents)
        return latents

    def query(self, latents, queries, voxel_idx=None):
        """
        Evaluate the geometry decoder at the query points.

        Args:
            queries (torch.FloatTensor): [B, N, 3]
            latents (torch.FloatTensor): latents handed to ``geo_decoder``.
                NOTE(review): the previous docstring gave ``[B, embed_dim]``; the
                trainer passes the output of ``decode`` here, so this is presumably
                ``[B, num_latents, width]`` - confirm.
            voxel_idx: accepted but not used by this implementation.

        Returns:
            logits (torch.FloatTensor): [B, N], occupancy logits (the decoder output
            with its trailing dimension squeezed).
        """
        logits = self.geo_decoder(queries=queries, latents=latents).squeeze(-1)

        return logits
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ultrashape/models/autoencoders/surface_extractors.py
DELETED
|
@@ -1,266 +0,0 @@
|
|
| 1 |
-
# ==============================================================================
|
| 2 |
-
# Original work Copyright (c) 2025 Tencent.
|
| 3 |
-
# Modified work Copyright (c) 2025 UltraShape Team.
|
| 4 |
-
#
|
| 5 |
-
# Modified by UltraShape on 2025.12.25
|
| 6 |
-
# ==============================================================================
|
| 7 |
-
|
| 8 |
-
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
| 9 |
-
# except for the third-party components listed below.
|
| 10 |
-
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
| 11 |
-
# in the respective licenses of these third-party components.
|
| 12 |
-
# Users must comply with all terms and conditions of original licenses of these third-party
|
| 13 |
-
# components and must ensure that the usage of the third party components adheres to
|
| 14 |
-
# all relevant laws and regulations.
|
| 15 |
-
|
| 16 |
-
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
| 17 |
-
# their software and algorithms, including trained model weights, parameters (including
|
| 18 |
-
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
| 19 |
-
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
| 20 |
-
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
| 21 |
-
|
| 22 |
-
from typing import Union, Tuple, List
|
| 23 |
-
|
| 24 |
-
import numpy as np
|
| 25 |
-
import torch
|
| 26 |
-
from skimage import measure
|
| 27 |
-
import cubvh
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
class Latent2MeshOutput:
    """Plain holder pairing a mesh's vertices (``mesh_v``) with its faces (``mesh_f``)."""

    def __init__(self, mesh_v=None, mesh_f=None):
        """Store the given vertices and faces; both default to ``None``."""
        self.mesh_v, self.mesh_f = mesh_v, mesh_f
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
def center_vertices(vertices):
    """Shift the vertices so that the centre of their bounding box sits at zero.

    The per-column minimum and maximum are taken along dimension 0 and their
    midpoint is subtracted from every row.
    """
    lower = vertices.min(dim=0).values
    upper = vertices.max(dim=0).values
    return vertices - 0.5 * (lower + upper)
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
class SurfaceExtractor:
    """Base class for surface extractors.

    Subclasses implement ``run`` for a single grid; ``__call__`` applies it to
    every grid in a batch and wraps the results in ``Latent2MeshOutput``.
    """

    def _compute_box_stat(self, bounds: Union[Tuple[float], List[float], float], octree_resolution: int):
        """
        Compute grid size, bounding box minimum coordinates, and bounding box size based on input
        bounds and resolution.

        Args:
            bounds (Union[Tuple[float], List[float], float]): Bounding box coordinates or a single
                number representing half side length.
                If a scalar (int or float), bounds are assumed symmetric around zero in all axes.
                Expected format if list/tuple: [xmin, ymin, zmin, xmax, ymax, zmax].
            octree_resolution (int): Resolution of the octree grid.

        Returns:
            grid_size (List[int]): Grid size along each axis (x, y, z), each equal to octree_resolution + 1.
            bbox_min (np.ndarray): Minimum coordinates of the bounding box (xmin, ymin, zmin).
            bbox_size (np.ndarray): Size of the bounding box along each axis (xmax - xmin, etc.).
        """
        # Accept ints as well as floats for the scalar form: an int such as ``1``
        # used to fall through to the slicing below and raise a TypeError.
        if isinstance(bounds, (int, float)):
            bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]

        bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6])
        bbox_size = bbox_max - bbox_min
        cells_per_axis = int(octree_resolution) + 1
        grid_size = [cells_per_axis, cells_per_axis, cells_per_axis]
        return grid_size, bbox_min, bbox_size

    def run(self, *args, **kwargs):
        """
        Abstract method to extract surface mesh from grid logits.

        This method should be implemented by subclasses.

        Raises:
            NotImplementedError: Always, since this is an abstract method.
        """
        # Raise (not return) the exception: returning the class handed callers a
        # bogus "result" instead of signalling the missing override.
        raise NotImplementedError

    def __call__(self, grid_logits, **kwargs):
        """
        Process a batch of grid logits to extract surface meshes.

        Args:
            grid_logits (torch.Tensor): Batch of grid logits with shape (batch_size, ...).
            **kwargs: Additional keyword arguments passed to the `run` method.

        Returns:
            List[Optional[Latent2MeshOutput]]: List of mesh outputs for each grid in the batch.
                If extraction fails for a grid, None is appended at that position.
        """
        outputs = []
        for i in range(grid_logits.shape[0]):
            try:
                vertices, faces = self.run(grid_logits[i], **kwargs)
                vertices = vertices.astype(np.float32)
                faces = np.ascontiguousarray(faces)
                outputs.append(Latent2MeshOutput(mesh_v=vertices, mesh_f=faces))

            except Exception:
                # Deliberate best-effort: one failed grid must not abort the batch.
                # The traceback is printed and the slot is filled with None.
                import traceback
                traceback.print_exc()
                outputs.append(None)

        return outputs
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
def get_sparse_valid_voxels(grid_logit: torch.Tensor):
    """Collect every grid cell whose eight corner values are all non-NaN.

    The cubic grid is walked in slabs along its first axis to bound the size
    of the temporary masks.

    Args:
        grid_logit: Tensor of shape (N, N, N).

    Returns:
        Tuple ``(sparse_coords, sparse_vertex_logits)``:
            - sparse_coords: (M, 3) integer indices of the valid cells' lowest corner.
            - sparse_vertex_logits: (M, 8) corner values in the order
              (x,y,z), (x+1,y,z), (x+1,y+1,z), (x,y+1,z),
              (x,y,z+1), (x+1,y,z+1), (x+1,y+1,z+1), (x,y+1,z+1).
        When no cell is valid, empty tensors of shape (0, 3) and (0, 8) are returned.

    Raises:
        TypeError: If ``grid_logit`` is not a tensor.
        ValueError: If ``grid_logit`` is not 3-D with three equal sides.
    """
    if not isinstance(grid_logit, torch.Tensor):
        raise TypeError("Input must be a PyTorch tensor.")
    is_cube = grid_logit.dim() == 3 and grid_logit.shape[0] == grid_logit.shape[1] == grid_logit.shape[2]
    if not is_cube:
        raise ValueError("Input tensor must have shape (N, N, N)")

    # Corner offsets, listed in the v0..v7 order of the output columns.
    corner_offsets = (
        (0, 0, 0), (1, 0, 0), (1, 1, 0), (0, 1, 0),
        (0, 0, 1), (1, 0, 1), (1, 1, 1), (0, 1, 1),
    )

    side = grid_logit.shape[0]
    device = grid_logit.device
    slab_depth = 128  # cells processed per slab along x
    cells_yz = side - 1

    coords_per_slab = []
    logits_per_slab = []

    for x0 in range(0, side - 1, slab_depth):
        x1 = min(x0 + slab_depth, side - 1)
        cells_x = x1 - x0

        # One extra plane is included so the last cell of the slab can see its
        # x+1 neighbours.
        slab = grid_logit[x0:x1 + 1, :, :]
        is_nan = torch.isnan(slab)

        # A cell is rejected as soon as any of its eight corners is NaN.
        touches_nan = None
        for dx, dy, dz in corner_offsets:
            shifted = is_nan[dx:dx + cells_x, dy:dy + cells_yz, dz:dz + cells_yz]
            touches_nan = shifted if touches_nan is None else (touches_nan | shifted)

        cell_idx = (~touches_nan).nonzero(as_tuple=False)

        if cell_idx.shape[0] > 0:
            cx, cy, cz = cell_idx[:, 0], cell_idx[:, 1], cell_idx[:, 2]
            corner_values = torch.stack(
                [slab[cx + dx, cy + dy, cz + dz] for dx, dy, dz in corner_offsets],
                dim=1,
            )

            # Indices are slab-local; only x needs the slab offset added back.
            shifted_idx = cell_idx.clone()
            shifted_idx[:, 0] += x0

            coords_per_slab.append(shifted_idx)
            logits_per_slab.append(corner_values)

        # Drop the slab temporaries before the next iteration allocates new ones.
        del slab, is_nan, touches_nan, cell_idx

    if not coords_per_slab:
        no_coords = torch.empty((0, 3), dtype=torch.long, device=device)
        no_logits = torch.empty((0, 8), dtype=grid_logit.dtype, device=device)
        return no_coords, no_logits

    return torch.cat(coords_per_slab, dim=0), torch.cat(logits_per_slab, dim=0)
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
class MCSurfaceExtractor(SurfaceExtractor):
    """Surface extractor that runs ``cubvh.sparse_marching_cubes`` on the NaN-free cells."""

    def run(self, grid_logit, *, mc_level, bounds, octree_resolution, **kwargs):
        """
        Extract a surface mesh with sparse Marching Cubes.

        Args:
            grid_logit (torch.Tensor): 3D grid logits tensor representing the scalar field.
            mc_level (float): The level (iso-value) at which to extract the surface.
            bounds (Union[Tuple[float], List[float], float]): Bounding box coordinates or half side length.
            octree_resolution (int): Resolution of the octree grid.
            **kwargs: Additional keyword arguments (ignored).

        Returns:
            Tuple[np.ndarray, np.ndarray]: Tuple containing:
                - vertices (np.ndarray): Extracted mesh vertices, scaled and translated to bounding
                    box coordinates.
                - faces (np.ndarray): Extracted mesh faces (triangles).
        """
        cell_coords, corner_logits = get_sparse_valid_voxels(grid_logit.detach())

        # Only the sparse corner values are promoted to float32, not the whole grid.
        mesh_v, mesh_f = cubvh.sparse_marching_cubes(cell_coords, corner_logits.float(), mc_level)
        mesh_v = mesh_v.cpu().numpy()
        mesh_f = mesh_f.cpu().numpy()

        # Map vertices from grid units into the requested bounding box.
        grid_size, bbox_min, bbox_size = self._compute_box_stat(bounds, octree_resolution)
        mesh_v = mesh_v / grid_size * bbox_size + bbox_min
        return mesh_v, mesh_f
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
class DMCSurfaceExtractor(SurfaceExtractor):
    """Surface extractor backed by ``diso.DiffDMC`` (imported lazily on first use)."""

    def run(self, grid_logit, *, octree_resolution, **kwargs):
        """
        Extract surface mesh using Differentiable Marching Cubes (DMC) algorithm.

        Args:
            grid_logit (torch.Tensor): 3D grid logits tensor representing the scalar field.
            octree_resolution (int): Resolution of the octree grid.
            **kwargs: Additional keyword arguments (ignored).

        Returns:
            Tuple[np.ndarray, np.ndarray]: Tuple containing:
                - vertices (np.ndarray): Extracted mesh vertices, centered and converted to numpy.
                - faces (np.ndarray): Extracted mesh faces (triangles), with reversed vertex order.

        Raises:
            ImportError: If the 'diso' package is not installed.
        """
        device = grid_logit.device
        if not hasattr(self, 'dmc'):
            # Guard only the import and catch only ImportError: the previous bare
            # ``except`` also swallowed failures while building or moving the
            # module (and KeyboardInterrupt) and reported them as a missing package.
            try:
                from diso import DiffDMC
            except ImportError as err:
                raise ImportError("Please install diso via `pip install diso`, or set mc_algo to 'mc'") from err
            # Cached on the instance; it stays on the device of the first grid seen.
            self.dmc = DiffDMC(dtype=torch.float32).to(device)

        # The field handed to DiffDMC is the negated logits scaled by the resolution.
        sdf = -grid_logit / octree_resolution
        sdf = sdf.to(torch.float32).contiguous()
        verts, faces = self.dmc(sdf, deform=None, return_quads=False, normalize=True)
        verts = center_vertices(verts)
        vertices = verts.detach().cpu().numpy()
        # Reverse the vertex order within each face.
        faces = faces.detach().cpu().numpy()[:, ::-1]
        return vertices, faces
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
# Registry mapping an ``mc_algo`` name to the extractor class that implements it.
SurfaceExtractors = dict(
    mc=MCSurfaceExtractor,
    dmc=DMCSurfaceExtractor,
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ultrashape/models/autoencoders/vae_trainer.py
DELETED
|
@@ -1,229 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
from contextlib import contextmanager
|
| 3 |
-
from typing import List, Tuple, Optional, Union
|
| 4 |
-
|
| 5 |
-
import torch
|
| 6 |
-
import torch.nn as nn
|
| 7 |
-
from torch.optim import lr_scheduler
|
| 8 |
-
import pytorch_lightning as pl
|
| 9 |
-
from pytorch_lightning.utilities import rank_zero_info
|
| 10 |
-
from pytorch_lightning.utilities import rank_zero_only
|
| 11 |
-
import trimesh
|
| 12 |
-
|
| 13 |
-
from ...utils.misc import instantiate_from_config, instantiate_non_trainable_model, instantiate_vae_model
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
def export_to_trimesh(mesh_output):
    """Convert mesh output(s) into ``trimesh.Trimesh`` objects with reversed face winding.

    Args:
        mesh_output: A single object exposing ``mesh_v`` / ``mesh_f``, ``None``, or a
            list whose entries are such objects or ``None``.

    Returns:
        For a list input, a list of the same length holding a ``trimesh.Trimesh``
        per mesh and ``None`` wherever the input entry was ``None``. For a single
        input, one ``trimesh.Trimesh``, or ``None`` if the input was ``None``.

    The input objects are left untouched. Earlier versions overwrote ``mesh_f``
    with the reversed faces, so converting the same object twice silently
    restored the original winding; a single ``None`` input also raised
    ``AttributeError`` instead of being passed through like list entries are.
    """
    def _convert(mesh):
        """Build one Trimesh from ``mesh``, or pass ``None`` through."""
        if mesh is None:
            return None
        # Reverse each face's vertex order on a local view only.
        flipped_faces = mesh.mesh_f[:, ::-1]
        return trimesh.Trimesh(mesh.mesh_v, flipped_faces)

    if isinstance(mesh_output, list):
        return [_convert(mesh) for mesh in mesh_output]
    return _convert(mesh_output)
|
| 31 |
-
|
| 32 |
-
class VAETrainer(pl.LightningModule):
    """Lightning module that trains the VAE on supervised query points.

    ``forward`` computes a reconstruction loss (MSE + L1 on queried logits) plus
    a KL term; ``validation_step`` additionally exports reconstructed meshes for
    the first few samples on ranks 0 and 1.
    """

    def __init__(
        self,
        *,
        vae_config,
        optimizer_cfg,
        loss_cfg,
        save_dir,
        mc_res,
        ckpt_path: Optional[str] = None,
        ignore_keys: Union[Tuple[str], List[str]] = (),
        torch_compile: bool = False,
    ):
        """Instantiate the VAE, optionally restore a checkpoint, and prepare the output dir.

        Args:
            vae_config: Config passed to ``instantiate_vae_model`` (with ``requires_grad=True``).
            optimizer_cfg: Config read in ``configure_optimizers`` (``.optimizer`` and,
                optionally, ``.scheduler``).
            loss_cfg: Provides ``lambda_logits`` and ``lambda_kl`` loss weights.
            save_dir: Root directory for validation mesh exports; created if missing.
            mc_res: Passed as ``octree_resolution`` when extracting validation meshes.
            ckpt_path: Optional checkpoint to load through ``init_from_ckpt``.
            ignore_keys: Substrings; checkpoint keys containing any of them are dropped.
            torch_compile: When truthy the VAE is compiled via ``torch.nn.Module.compile``.
        """
        super().__init__()

        # ========= init optimizer config ========= #
        self.optimizer_cfg = optimizer_cfg
        self.loss_cfg = loss_cfg
        self.ckpt_path = ckpt_path
        self.vae_model = instantiate_vae_model(vae_config, requires_grad=True)
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

        self.mc_res = mc_res
        self.save_root = save_dir
        # ``exist_ok`` avoids the check-then-create race when several ranks
        # construct the module at the same time.
        os.makedirs(save_dir, exist_ok=True)

        # ========= torch compile to accelerate ========= #
        self.torch_compile = torch_compile
        if self.torch_compile:
            torch.nn.Module.compile(self.vae_model)
            print('*' * 100)
            print('Compile model for acceleration')
            print('*' * 100)

    def init_from_ckpt(self, path, ignore_keys=()):
        """Load weights from ``path`` non-strictly, dropping keys that match ``ignore_keys``.

        Args:
            path: Checkpoint file readable by ``torch.load``.
            ignore_keys: Substrings; any state-dict key containing one is removed
                before loading.
        """
        # NOTE(review): ``torch.load`` unpickles the file; only load checkpoints
        # from trusted sources.
        ckpt = torch.load(path, map_location="cpu")
        if 'state_dict' not in ckpt:
            # deepspeed ckpt: flat mapping whose keys carry a '_forward_module.' prefix
            state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
        else:
            state_dict = ckpt["state_dict"]

        # Iterate over a snapshot of the keys because entries are deleted. Each key
        # is removed at most once: the old nested loop raised KeyError when two
        # ignore patterns matched the same key.
        for k in list(state_dict.keys()):
            if any(ik in k for ik in ignore_keys):
                print("Deleting key {} from state_dict.".format(k))
                del state_dict[k]

        # NOTE(review): a former remapping of legacy 'vae_model.encoder.input_proj'
        # weights onto input_proj_kv / input_proj_q was disabled here; such keys, if
        # present, are presumably just reported as unexpected - confirm.
        missing, unexpected = self.load_state_dict(state_dict, strict=False)
        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
            print(f"Unexpected Keys: {unexpected}")

    def configure_optimizers(self) -> Tuple[List, List]:
        """Build the optimizer and, if configured, a per-step LambdaLR scheduler.

        Returns:
            ``(optimizers, schedulers)`` lists in the form Lightning expects.
        """
        # ``learning_rate`` is not assigned anywhere in this class.
        # NOTE(review): presumably set on the module by the training script - confirm.
        lr = self.learning_rate

        params_list = []
        trainable_parameters = list(self.vae_model.parameters())
        params_list.append({'params': trainable_parameters, 'lr': lr})

        optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=params_list, lr=lr)
        if hasattr(self.optimizer_cfg, 'scheduler'):
            scheduler_func = instantiate_from_config(
                self.optimizer_cfg.scheduler,
                max_decay_steps=self.trainer.max_steps,
                lr_max=lr
            )
            scheduler = {
                "scheduler": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule),
                "interval": "step",
                "frequency": 1
            }
            schedulers = [scheduler]
        else:
            schedulers = []
        optimizers = [optimizer]

        return optimizers, schedulers

    def on_train_epoch_start(self) -> None:
        """Reseed all RNGs with this process's global rank at the start of each epoch."""
        pl.seed_everything(self.trainer.global_rank)

    def forward(self, batch):
        """Run encode -> decode -> query on a batch and compute the losses.

        Args:
            batch: Mapping with ``"surface"`` (encoder input) and three supervision
                sets ``"sup_near_uniform"``, ``"sup_near_sharp"``, ``"sup_space"``.
                For each set, channels ``[:3]`` are used as query points and
                channel ``3`` as the supervision value.

        Returns:
            ``(loss_dict, latents)`` where ``loss_dict`` has keys ``"loss"``,
            ``"loss_logits"`` and ``"loss_kl"`` and ``latents`` is the output of
            ``vae_model.decode``.
        """
        sup_pc_s_list = [batch["sup_near_uniform"], batch["sup_near_sharp"], batch["sup_space"]]
        rand_points = [sup_pc_s[:, :, :3] for sup_pc_s in sup_pc_s_list]
        rand_points_val = [sup_pc_s[:, :, 3:] for sup_pc_s in sup_pc_s_list]

        rand_points = torch.cat(rand_points, dim=1)
        target = torch.cat(rand_points_val, dim=1)[..., 0]
        # The supervision value is negated before being compared with the logits.
        target = -target

        latents, posterior = self.vae_model.encode(
            batch['surface'], sample_posterior=True, need_kl=True)
        latents = self.vae_model.decode(latents)
        logits = self.vae_model.query(latents, rand_points)

        # KL summed over everything, then divided by the size of its first dimension.
        loss_kl = posterior.kl()
        loss_kl = torch.sum(loss_kl) / loss_kl.shape[0]

        criteria = torch.nn.MSELoss()
        criteria2 = torch.nn.L1Loss()
        loss_logits = criteria(logits, target).mean() + criteria2(logits, target).mean()
        loss = self.loss_cfg.lambda_logits * loss_logits + self.loss_cfg.lambda_kl * loss_kl

        loss_dict = {
            "loss": loss,
            "loss_logits": loss_logits,
            "loss_kl": loss_kl
        }
        return loss_dict, latents

    def _log_losses(self, split, loss):
        """Log the detached loss terms and the current learning rate under ``split/``."""
        loss_dict = {
            f"{split}/total_loss": loss["loss"].detach(),
            f"{split}/loss_logits": loss["loss_logits"].detach(),
            f"{split}/loss_kl": loss["loss_kl"].detach(),
            f"{split}/lr_abs": self.optimizers().param_groups[0]['lr'],
        }
        self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True)

    def training_step(self, batch, batch_idx, optimizer_idx=0):
        """Compute and log the training losses; returns the loss dict (with key ``"loss"``)."""
        loss, latents = self.forward(batch)
        self._log_losses('train', loss)

        return loss

    def validation_step(self, batch, batch_idx, optimizer_idx=0):
        """Compute and log validation losses; ranks 0 and 1 also export up to 5 meshes.

        Meshes are written to
        ``{save_root}/gs{global_step}_rank{rank}/recon_{name}_mc{mc_res}.obj``.
        """
        loss, latents = self.forward(batch)
        self._log_losses('val', loss)

        if self.trainer.global_rank < 2:
            with torch.no_grad():
                save_dir = f"{self.save_root}/gs{self.global_step:010d}_rank{self.trainer.global_rank}"
                os.makedirs(save_dir, exist_ok=True)
                uids = batch.get('uid')
                for i, latent in enumerate(latents[:5]):
                    mesh, grid_logits = self.vae_model.latents2mesh(
                        latent[None],
                        output_type='trimesh',
                        bounds=1.01,
                        mc_level=0.0,
                        num_chunks=20000,
                        octree_resolution=self.mc_res,
                        mc_algo='mc',
                        enable_pbar=True
                    )

                    # The surface extractor yields None for a grid it failed on;
                    # skip it rather than abort the whole validation step.
                    if mesh[0] is None:
                        continue
                    mesh = export_to_trimesh(mesh[0])

                    # ``batch.get`` may return None; fall back to an index-based
                    # name instead of failing on ``uids[i]``.
                    if uids is not None:
                        stem = os.path.splitext(os.path.basename(uids[i]))[0]
                    else:
                        stem = f"batch{batch_idx}_{i}"
                    save_path = f"{save_dir}/recon_{stem}_mc{self.mc_res}.obj"
                    mesh.export(save_path)

        return loss
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ultrashape/models/autoencoders/volume_decoders.py
DELETED
|
@@ -1,440 +0,0 @@
|
|
| 1 |
-
# ==============================================================================
|
| 2 |
-
# Original work Copyright (c) 2025 Tencent.
|
| 3 |
-
# Modified work Copyright (c) 2025 UltraShape Team.
|
| 4 |
-
#
|
| 5 |
-
# Modified by UltraShape on 2025.12.25
|
| 6 |
-
# ==============================================================================
|
| 7 |
-
|
| 8 |
-
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
| 9 |
-
# except for the third-party components listed below.
|
| 10 |
-
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
| 11 |
-
# in the respective licenses of these third-party components.
|
| 12 |
-
# Users must comply with all terms and conditions of original licenses of these third-party
|
| 13 |
-
# components and must ensure that the usage of the third party components adheres to
|
| 14 |
-
# all relevant laws and regulations.
|
| 15 |
-
|
| 16 |
-
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
| 17 |
-
# their software and algorithms, including trained model weights, parameters (including
|
| 18 |
-
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
| 19 |
-
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
| 20 |
-
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
| 21 |
-
|
| 22 |
-
from typing import Union, Tuple, List, Callable
|
| 23 |
-
|
| 24 |
-
import numpy as np
|
| 25 |
-
import torch
|
| 26 |
-
import torch.nn as nn
|
| 27 |
-
import torch.nn.functional as F
|
| 28 |
-
from einops import repeat
|
| 29 |
-
from tqdm import tqdm
|
| 30 |
-
|
| 31 |
-
from .attention_blocks import CrossAttentionDecoder
|
| 32 |
-
from .attention_processors import FlashVDMCrossAttentionProcessor, FlashVDMTopMCrossAttentionProcessor
|
| 33 |
-
from ...utils import logger
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
def extract_near_surface_volume_fn(input_tensor: torch.Tensor, alpha: float):
    """Flag the voxels of a 3D grid that sit next to a sign change.

    ``alpha`` is added to ``input_tensor`` and the sign of every voxel is
    compared with the sign of its six axis-aligned neighbours. Neighbours
    outside the grid are taken from the replicated border, and neighbours whose
    shifted value is not greater than ``-9000`` are replaced by the voxel's own
    value, so neither can produce a sign change.

    Args:
        input_tensor: Tensor indexed on exactly three axes.
        alpha: Offset added to every value before the sign test.

    Returns:
        An ``int32`` tensor shaped like ``input_tensor`` holding 1 where at
        least one neighbour has a different sign and the shifted value itself
        is greater than ``-9000``, and 0 everywhere else.
    """
    val = input_tensor + alpha
    valid_mask = val > -9000
    sign = torch.sign(val.to(torch.float32))

    def neighbor_has_same_sign(shift, axis):
        """Compare each voxel's sign with the voxel ``shift`` steps along ``axis``."""
        # F.pad lists the pads starting from the LAST dimension:
        # (axis2_low, axis2_high, axis1_low, axis1_high, axis0_low, axis0_high).
        # A positive shift reads from higher indices, so the high side is padded.
        pad = [0, 0, 0, 0, 0, 0]
        pad[2 * (2 - axis) + (1 if shift > 0 else 0)] = abs(shift)
        padded = F.pad(val.unsqueeze(0).unsqueeze(0), pad, mode='replicate')
        padded = padded.squeeze(0).squeeze(0)

        window = [slice(None)] * 3
        window[axis] = slice(shift, None) if shift > 0 else slice(None, shift)
        # Index with a tuple: passing a list of slices relies on non-tuple
        # sequence indexing, which recent PyTorch releases deprecate.
        neighbor = padded[tuple(window)]
        # Unevaluated neighbours fall back to the voxel's own value.
        neighbor = torch.where(neighbor > -9000, neighbor, val)
        return torch.sign(neighbor.to(torch.float32)) == sign

    # directions: (shift, axis)
    all_same = torch.ones_like(val, dtype=torch.bool)
    for shift, axis in ((1, 0), (-1, 0), (1, 1), (-1, 1), (1, 2), (-1, 2)):
        all_same &= neighbor_has_same_sign(shift, axis)

    # Keep voxels where ANY neighbour differs, restricted to valid voxels.
    return ((~all_same) & valid_mask).to(torch.int32)
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
def generate_dense_grid_points(
    bbox_min: np.ndarray,
    bbox_max: np.ndarray,
    octree_resolution: int,
    indexing: str = "ij",
):
    """Build a regular lattice of 3D points spanning a bounding box.

    Args:
        bbox_min: Lower corner of the box (three components are read).
        bbox_max: Upper corner of the box (three components are read).
        octree_resolution: Number of cells per axis; every axis is sampled at
            ``octree_resolution + 1`` points, both ends included.
        indexing: Forwarded to ``np.meshgrid``.

    Returns:
        A tuple ``(xyz, grid_size, length)``: the float32 point lattice with
        the coordinates on the last axis, the per-axis point counts as a list
        of three ints, and ``bbox_max - bbox_min``.
    """
    points_per_axis = int(octree_resolution) + 1
    axes = [
        np.linspace(bbox_min[dim], bbox_max[dim], points_per_axis, dtype=np.float32)
        for dim in range(3)
    ]
    xyz = np.stack(np.meshgrid(*axes, indexing=indexing), axis=-1)
    return xyz, [points_per_axis] * 3, bbox_max - bbox_min
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
class VanillaVolumeDecoder:
    """Decode latents into a dense logit grid by querying every grid point."""

    @torch.no_grad()
    def __call__(
        self,
        latents: torch.FloatTensor,
        geo_decoder: Callable,
        bounds: Union[Tuple[float], List[float], float] = 1.01,
        num_chunks: int = 10000,
        octree_resolution: int = None,
        enable_pbar: bool = True,
        **kwargs,
    ):
        """Evaluate ``geo_decoder`` on a regular grid covering ``bounds``.

        Args:
            latents: Latent tensor; its first dimension is the batch size and
                its device/dtype are used for the query points.
            geo_decoder: Callable invoked as
                ``geo_decoder(queries=..., latents=...)``.
            bounds: Either a single number ``b`` (the cube ``[-b, b]^3``) or a
                six-element sequence ``[xmin, ymin, zmin, xmax, ymax, zmax]``.
            num_chunks: Maximum number of query points per decoder call.
            octree_resolution: Cells per axis; must be provided by the caller.
            enable_pbar: Show the tqdm progress bar when true.
            **kwargs: Accepted and ignored.

        Returns:
            float32 tensor of shape ``(batch, R + 1, R + 1, R + 1)`` with
            ``R = octree_resolution``.
        """
        device = latents.device
        dtype = latents.dtype
        batch_size = latents.shape[0]

        # 1. generate query points
        # Accept ints as well as floats: ``bounds=1`` used to fall through and
        # fail on the slicing below.
        if isinstance(bounds, (int, float)):
            bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]

        bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6])
        xyz_samples, grid_size, _ = generate_dense_grid_points(
            bbox_min=bbox_min,
            bbox_max=bbox_max,
            octree_resolution=octree_resolution,
            indexing="ij"
        )
        xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype).contiguous().reshape(-1, 3)

        # 2. latents to 3d volume
        batch_logits = []
        for start in tqdm(range(0, xyz_samples.shape[0], num_chunks), desc="Volume Decoding",
                          disable=not enable_pbar):
            chunk_queries = xyz_samples[start: start + num_chunks, :]
            chunk_queries = repeat(chunk_queries, "p c -> b p c", b=batch_size)
            batch_logits.append(geo_decoder(queries=chunk_queries, latents=latents))

        # Chunks are concatenated along the point axis, then folded back into
        # the grid layout (assumes one logit per query point -- TODO confirm
        # against the decoder).
        grid_logits = torch.cat(batch_logits, dim=1)
        return grid_logits.view((batch_size, *grid_size)).float()
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
class HierarchicalVolumeDecoding:
    """Coarse-to-fine volume decoding that only refines voxels near the surface."""

    @torch.no_grad()
    def __call__(
        self,
        latents: torch.FloatTensor,
        geo_decoder: Callable,
        bounds: Union[Tuple[float], List[float], float] = 1.01,
        num_chunks: int = 10000,
        mc_level: float = 0.0,
        octree_resolution: int = None,
        min_resolution: int = 63,
        enable_pbar: bool = True,
        **kwargs,
    ):
        """Decode a dense coarse grid, then refine near-surface voxels level by level.

        The resolution ladder is built by halving ``octree_resolution`` while it
        stays ``>= min_resolution``; the coarsest level is decoded densely and
        each following level only queries points around voxels flagged by
        ``extract_near_surface_volume_fn`` or whose absolute logit is below 0.95.

        Args:
            latents: Latent tensor. The refinement stages squeeze the batch
                axis and read batch element 0 only, so a batch of one is
                expected once more than one level is used.
            geo_decoder: Callable invoked as
                ``geo_decoder(queries=..., latents=...)``.
            bounds: A single number ``b`` (cube ``[-b, b]^3``) or a six-element
                sequence ``[xmin, ymin, zmin, xmax, ymax, zmax]``.
            num_chunks: Maximum number of query points per decoder call.
            mc_level: Offset passed to ``extract_near_surface_volume_fn``.
            octree_resolution: Finest number of cells per axis.
            min_resolution: Smallest resolution kept in the ladder.
            enable_pbar: Show tqdm progress bars when true.
            **kwargs: Accepted and ignored.

        Returns:
            Tensor of shape ``(1, R + 1, R + 1, R + 1)`` for the finest level
            ``R``; grid points that were never queried hold ``NaN``.
        """
        device = latents.device
        dtype = latents.dtype

        resolutions = []
        if octree_resolution < min_resolution:
            resolutions.append(octree_resolution)
        while octree_resolution >= min_resolution:
            resolutions.append(octree_resolution)
            octree_resolution = octree_resolution // 2
        resolutions.reverse()

        # 1. generate query points
        # Accept ints as well as floats for a scalar half-extent.
        if isinstance(bounds, (int, float)):
            bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
        bbox_min = np.array(bounds[0:3])
        bbox_max = np.array(bounds[3:6])
        bbox_size = bbox_max - bbox_min

        xyz_samples, grid_size, _ = generate_dense_grid_points(
            bbox_min=bbox_min,
            bbox_max=bbox_max,
            octree_resolution=resolutions[0],
            indexing="ij"
        )

        # 3x3x3 all-ones convolution: any positive output marks a voxel with a
        # flagged voxel in its 26-neighbourhood (a binary dilation).
        dilate = nn.Conv3d(1, 1, 3, padding=1, bias=False, device=device, dtype=dtype)
        dilate.weight = torch.nn.Parameter(torch.ones(dilate.weight.shape, dtype=dtype, device=device))

        grid_size = np.array(grid_size)
        xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype).contiguous().reshape(-1, 3)

        # 2. latents to 3d volume
        batch_logits = []
        batch_size = latents.shape[0]
        for start in tqdm(range(0, xyz_samples.shape[0], num_chunks),
                          desc=f"Hierarchical Volume Decoding [r{resolutions[0] + 1}]",
                          disable=not enable_pbar):
            queries = xyz_samples[start: start + num_chunks, :]
            batch_queries = repeat(queries, "p c -> b p c", b=batch_size)
            batch_logits.append(geo_decoder(queries=batch_queries, latents=latents))

        grid_logits = torch.cat(batch_logits, dim=1).view((batch_size, grid_size[0], grid_size[1], grid_size[2]))

        for octree_depth_now in resolutions[1:]:
            grid_size = np.array([octree_depth_now + 1] * 3)
            resolution = bbox_size / octree_depth_now
            next_index = torch.zeros(tuple(grid_size), dtype=dtype, device=device)
            curr_points = extract_near_surface_volume_fn(grid_logits.squeeze(0), mc_level)
            curr_points += grid_logits.squeeze(0).abs() < 0.95

            # Non-final levels dilate the coarse selection once and the fine
            # selection once; the final level dilates only the fine selection,
            # twice.
            expand_num = 0 if octree_depth_now == resolutions[-1] else 1
            for _ in range(expand_num):
                curr_points = dilate(curr_points.unsqueeze(0).to(dtype)).squeeze(0)
            (cidx_x, cidx_y, cidx_z) = torch.where(curr_points > 0)
            next_index[cidx_x * 2, cidx_y * 2, cidx_z * 2] = 1
            for _ in range(2 - expand_num):
                next_index = dilate(next_index.unsqueeze(0)).squeeze(0)
            nidx = torch.where(next_index > 0)

            # Store shape before deleting
            next_index_shape = next_index.shape
            del next_index
            torch.cuda.empty_cache()

            # Convert integer grid indices to world coordinates in floating
            # point. Building the scale/offset with the integer index dtype
            # truncated the cell size to 0 and bbox_min to an integer, sending
            # every query to the same corner.
            next_points = torch.stack(nidx, dim=1)
            next_points = (next_points * torch.tensor(resolution, dtype=torch.float32, device=device) +
                           torch.tensor(bbox_min, dtype=torch.float32, device=device))
            batch_logits = []
            for start in tqdm(range(0, next_points.shape[0], num_chunks),
                              desc=f"Hierarchical Volume Decoding [r{octree_depth_now + 1}]",
                              disable=not enable_pbar):
                queries = next_points[start: start + num_chunks, :]
                batch_queries = repeat(queries, "p c -> b p c", b=batch_size)
                logits = geo_decoder(queries=batch_queries.to(latents.dtype), latents=latents)
                batch_logits.append(logits)

            # Allocate the fine grid only now, after the index grid was freed.
            # -10000 marks "not queried" and is recognised as invalid by
            # extract_near_surface_volume_fn on the next level.
            next_logits = torch.full(next_index_shape, -10000., dtype=dtype, device=device)
            grid_logits = torch.cat(batch_logits, dim=1)
            next_logits[nidx] = grid_logits[0, ..., 0]
            grid_logits = next_logits.unsqueeze(0)
        grid_logits[grid_logits == -10000.] = float('nan')

        return grid_logits
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
class FlashVDMVolumeDecoding:
    """Hierarchical volume decoding that drives a FlashVDM cross-attention processor."""

    def __init__(self, topk_mode='mean'):
        """Select the cross-attention processor.

        Args:
            topk_mode: ``'mean'`` uses ``FlashVDMCrossAttentionProcessor``,
                ``'merge'`` uses ``FlashVDMTopMCrossAttentionProcessor``.

        Raises:
            ValueError: If ``topk_mode`` is neither ``'mean'`` nor ``'merge'``.
        """
        if topk_mode not in ['mean', 'merge']:
            raise ValueError(f'Unsupported topk_mode {topk_mode}, available: {["mean", "merge"]}')

        if topk_mode == 'mean':
            self.processor = FlashVDMCrossAttentionProcessor()
        else:
            self.processor = FlashVDMTopMCrossAttentionProcessor()

    @torch.no_grad()
    def __call__(
        self,
        latents: torch.FloatTensor,
        geo_decoder: CrossAttentionDecoder,
        bounds: Union[Tuple[float], List[float], float] = 1.01,
        num_chunks: int = 10000,
        mc_level: float = 0.0,
        octree_resolution: int = None,
        min_resolution: int = 63,
        mini_grid_num: int = 4,
        enable_pbar: bool = True,
        **kwargs,
    ):
        """Decode a coarse grid in mini-grid blocks, then refine near-surface voxels.

        Args:
            latents: Latent tensor. The batch axis is squeezed for the coarse
                pass and the final ``view`` uses ``latents.shape[0]``, so a
                batch of one is expected.
            geo_decoder: Decoder exposing ``set_cross_attention_processor`` and
                callable as ``geo_decoder(queries=..., latents=...)``.
            bounds: A single number ``b`` (cube ``[-b, b]^3``) or a six-element
                sequence ``[xmin, ymin, zmin, xmax, ymax, zmax]``.
            num_chunks: Maximum number of query points per decoder call.
            mc_level: Offset passed to ``extract_near_surface_volume_fn``.
            octree_resolution: Requested finest number of cells per axis.
            min_resolution: Smallest resolution kept in the ladder.
            mini_grid_num: Number of mini-grids per axis for the coarse pass.
            enable_pbar: Show the coarse-pass tqdm progress bar when true.
            **kwargs: Accepted and ignored.

        Returns:
            Tensor of shape ``(1, R + 1, R + 1, R + 1)`` for the finest level
            ``R`` of the adjusted ladder; grid points that were never queried
            hold ``NaN``.
        """
        processor = self.processor
        geo_decoder.set_cross_attention_processor(processor)

        device = latents.device
        dtype = latents.dtype

        resolutions = []
        if octree_resolution < min_resolution:
            resolutions.append(octree_resolution)
        while octree_resolution >= min_resolution:
            resolutions.append(octree_resolution)
            octree_resolution = octree_resolution // 2
        resolutions.reverse()
        # Make the coarse point count (resolution + 1) a multiple of
        # mini_grid_num, then rebuild the finer levels as exact doublings.
        resolutions[0] = round(resolutions[0] / mini_grid_num) * mini_grid_num - 1
        for level in range(1, len(resolutions)):
            resolutions[level] = resolutions[0] * 2 ** level

        logger.info(f"FlashVDMVolumeDecoding Resolution: {resolutions}")

        # 1. generate query points
        # Accept ints as well as floats for a scalar half-extent.
        if isinstance(bounds, (int, float)):
            bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
        bbox_min = np.array(bounds[0:3])
        bbox_max = np.array(bounds[3:6])
        bbox_size = bbox_max - bbox_min

        xyz_samples, grid_size, _ = generate_dense_grid_points(
            bbox_min=bbox_min,
            bbox_max=bbox_max,
            octree_resolution=resolutions[0],
            indexing="ij"
        )

        # 3x3x3 all-ones convolution used as a binary dilation.
        dilate = nn.Conv3d(1, 1, 3, padding=1, bias=False, device=device, dtype=dtype)
        dilate.weight = torch.nn.Parameter(torch.ones(dilate.weight.shape, dtype=dtype, device=device))

        grid_size = np.array(grid_size)

        # 2. latents to 3d volume
        xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype)
        batch_size = latents.shape[0]
        mini_grid_size = xyz_samples.shape[0] // mini_grid_num
        # Split the grid into mini_grid_num^3 blocks of mini_grid_size^3 points;
        # each block becomes one row of queries.
        xyz_samples = xyz_samples.view(
            mini_grid_num, mini_grid_size,
            mini_grid_num, mini_grid_size,
            mini_grid_num, mini_grid_size, 3
        ).permute(
            0, 2, 4, 1, 3, 5, 6
        ).reshape(
            -1, mini_grid_size * mini_grid_size * mini_grid_size, 3
        )
        batch_logits = []
        num_batchs = max(num_chunks // xyz_samples.shape[1], 1)
        for start in tqdm(range(0, xyz_samples.shape[0], num_batchs),
                          desc="FlashVDM Volume Decoding", disable=not enable_pbar):
            queries = xyz_samples[start: start + num_batchs, :]
            batch = queries.shape[0]
            batch_latents = repeat(latents.squeeze(0), "p c -> b p c", b=batch)
            processor.topk = True

            # Chunk queries along dim 1 if too large
            if queries.shape[1] > num_chunks:
                batch_logits_sub = []
                for sub_start in range(0, queries.shape[1], num_chunks):
                    sub_queries = queries[:, sub_start: sub_start + num_chunks, :]
                    batch_logits_sub.append(geo_decoder(queries=sub_queries, latents=batch_latents))
                logits = torch.cat(batch_logits_sub, dim=1)
            else:
                logits = geo_decoder(queries=queries, latents=batch_latents)

            batch_logits.append(logits)
        # Undo the mini-grid blocking to recover the dense grid layout.
        grid_logits = torch.cat(batch_logits, dim=0).reshape(
            mini_grid_num, mini_grid_num, mini_grid_num,
            mini_grid_size, mini_grid_size,
            mini_grid_size
        ).permute(0, 3, 1, 4, 2, 5).contiguous().view(
            (batch_size, grid_size[0], grid_size[1], grid_size[2])
        )

        for octree_depth_now in resolutions[1:]:
            grid_size = np.array([octree_depth_now + 1] * 3)
            resolution = bbox_size / octree_depth_now
            next_index = torch.zeros(tuple(grid_size), dtype=dtype, device=device)
            curr_points = extract_near_surface_volume_fn(grid_logits.squeeze(0), mc_level)
            curr_points += grid_logits.squeeze(0).abs() < 0.95

            expand_num = 0 if octree_depth_now == resolutions[-1] else 1
            # NOTE(review): both dilation passes are taken to sit inside this
            # loop (i.e. none on the final level) -- confirm against the
            # original file's indentation.
            for _ in range(expand_num):
                curr_points = dilate(curr_points.unsqueeze(0).to(dtype)).squeeze(0)
                curr_points = dilate(curr_points.unsqueeze(0).to(dtype)).squeeze(0)
            (cidx_x, cidx_y, cidx_z) = torch.where(curr_points > 0)

            next_index[cidx_x * 2, cidx_y * 2, cidx_z * 2] = 1
            for _ in range(2 - expand_num):
                next_index = dilate(next_index.unsqueeze(0)).squeeze(0)
            nidx = torch.where(next_index > 0)

            # Store shape before deleting
            next_index_shape = next_index.shape
            del next_index
            torch.cuda.empty_cache()

            next_points = torch.stack(nidx, dim=1)
            next_points = (next_points * torch.tensor(resolution, dtype=torch.float32, device=device) +
                           torch.tensor(bbox_min, dtype=torch.float32, device=device))

            # Bucket the points into a query_grid_num^3 lattice and sort them by
            # bucket so that each decoder call sees spatially grouped queries.
            query_grid_num = 6
            min_val = next_points.min(axis=0).values
            max_val = next_points.max(axis=0).values
            # Guard against a zero extent (all points sharing one coordinate),
            # which would otherwise divide by zero and yield NaN bucket ids.
            extent = (max_val - min_val).clamp_min(1e-8)
            vol_queries_index = (next_points - min_val) / extent * (query_grid_num - 0.001)
            index = torch.floor(vol_queries_index).long()
            index = index[..., 0] * (query_grid_num ** 2) + index[..., 1] * query_grid_num + index[..., 2]
            index = index.sort()
            next_points = next_points[index.indices].unsqueeze(0).contiguous()
            unique_values = torch.unique(index.values, return_counts=True)
            grid_logits = torch.zeros((next_points.shape[1]), dtype=latents.dtype, device=latents.device)

            # input_grid = [bucket ids, point counts] for the pending decoder
            # call; it is handed to the processor through ``processor.topk``.
            input_grid = [[], []]
            logits_grid_list = []
            start_num = 0
            sum_num = 0
            for grid_index, count in zip(unique_values[0].cpu().tolist(), unique_values[1].cpu().tolist()):
                remaining_count = count
                while remaining_count > 0:
                    space_left = num_chunks - sum_num
                    # If buffer is full, flush it
                    if space_left <= 0:
                        processor.topk = input_grid
                        logits_grid = geo_decoder(queries=next_points[:, start_num:start_num + sum_num],
                                                  latents=latents)
                        start_num = start_num + sum_num
                        logits_grid_list.append(logits_grid)
                        input_grid = [[], []]
                        sum_num = 0
                        space_left = num_chunks

                    # A bucket larger than the remaining space is split across calls.
                    take = min(remaining_count, space_left)
                    input_grid[0].append(grid_index)
                    input_grid[1].append(take)
                    sum_num += take
                    remaining_count -= take
            if sum_num > 0:
                processor.topk = input_grid
                logits_grid = geo_decoder(queries=next_points[:, start_num:start_num + sum_num], latents=latents)
                logits_grid_list.append(logits_grid)
            logits_grid = torch.cat(logits_grid_list, dim=1)
            # Scatter the sorted results back to the original point order.
            grid_logits[index.indices] = logits_grid.squeeze(0).squeeze(-1)

            # Delayed allocation of next_logits; -10000 marks "not queried".
            next_logits = torch.full(next_index_shape, -10000., dtype=dtype, device=device)
            next_logits[nidx] = grid_logits
            grid_logits = next_logits.unsqueeze(0)

        grid_logits[grid_logits == -10000.] = float('nan')

        return grid_logits
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ultrashape/models/conditioner_mask.py
DELETED
|
@@ -1,337 +0,0 @@
|
|
| 1 |
-
# ==============================================================================
|
| 2 |
-
# Original work Copyright (c) 2025 Tencent.
|
| 3 |
-
# Modified work Copyright (c) 2025 UltraShape Team.
|
| 4 |
-
#
|
| 5 |
-
# Modified by UltraShape on 2025.12.25
|
| 6 |
-
# ==============================================================================
|
| 7 |
-
|
| 8 |
-
# Open Source Model Licensed under the Apache License Version 2.0
|
| 9 |
-
# and Other Licenses of the Third-Party Components therein:
|
| 10 |
-
# The below Model in this distribution may have been modified by THL A29 Limited
|
| 11 |
-
# ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
|
| 12 |
-
|
| 13 |
-
# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
|
| 14 |
-
# The below software and/or models in this distribution may have been
|
| 15 |
-
# modified by THL A29 Limited ("Tencent Modifications").
|
| 16 |
-
# All Tencent Modifications are Copyright (C) THL A29 Limited.
|
| 17 |
-
|
| 18 |
-
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
| 19 |
-
# except for the third-party components listed below.
|
| 20 |
-
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
| 21 |
-
# in the respective licenses of these third-party components.
|
| 22 |
-
# Users must comply with all terms and conditions of original licenses of these third-party
|
| 23 |
-
# components and must ensure that the usage of the third party components adheres to
|
| 24 |
-
# all relevant laws and regulations.
|
| 25 |
-
|
| 26 |
-
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
| 27 |
-
# their software and algorithms, including trained model weights, parameters (including
|
| 28 |
-
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
| 29 |
-
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
| 30 |
-
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
import numpy as np
|
| 34 |
-
import torch
|
| 35 |
-
import torch.nn as nn
|
| 36 |
-
from torchvision import transforms
|
| 37 |
-
from transformers import (
|
| 38 |
-
CLIPVisionModelWithProjection,
|
| 39 |
-
CLIPVisionConfig,
|
| 40 |
-
Dinov2Model,
|
| 41 |
-
Dinov2Config,
|
| 42 |
-
)
|
| 43 |
-
from transformers import AutoImageProcessor, AutoModel
|
| 44 |
-
|
| 45 |
-
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    Build a 1D sine/cosine positional embedding.

    embed_dim: output dimension for each position (must be even)
    pos: array of positions to be encoded, flattened to size (M,)
    out: (M, D) array; the first D/2 columns are sines, the last D/2 cosines
    """
    assert embed_dim % 2 == 0
    half_dim = embed_dim // 2

    # Frequencies 1 / 10000^(i / (D/2)) for i = 0 .. D/2 - 1.
    exponents = np.arange(half_dim, dtype=np.float64) / (embed_dim / 2.)
    freqs = 1. / 10000 ** exponents  # (D/2,)

    flat_pos = pos.reshape(-1)  # (M,)
    angles = flat_pos[:, None] * freqs[None, :]  # (M, D/2), outer product

    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
class ImageEncoder(nn.Module):
    """Frozen vision backbone that turns images into token embeddings.

    The constructor reads ``self.mean`` / ``self.std`` and, when a config dict
    is given, ``self.MODEL_CLASS`` / ``self.MODEL_CONFIG_CLASS``; none of them
    is defined here, so they must be supplied as class attributes by a
    subclass.
    """

    def __init__(
        self,
        version=None,
        config=None,
        use_cls_token=True,
        image_size=224,
        **kwargs,
    ):
        """Load or build the backbone and set up the preprocessing transforms.

        Args:
            version: Identifier passed to ``AutoModel.from_pretrained`` when
                ``config`` is ``None``.
            config: Optional dict used to build the model from
                ``MODEL_CONFIG_CLASS.from_dict`` instead of loading weights.
            use_cls_token: Keep the first token of the backbone output.
            image_size: Side length images and masks are resized/cropped to.
            **kwargs: Accepted and ignored.
        """
        super().__init__()

        if config is None:
            self.model = AutoModel.from_pretrained(version)
        else:
            self.model = self.MODEL_CLASS(self.MODEL_CONFIG_CLASS.from_dict(config))

        # The backbone is used for inference only.
        self.model.eval()
        self.model.requires_grad_(False)
        self.use_cls_token = use_cls_token
        # 14 is the patch side length assumed throughout this class.
        self.size = image_size // 14
        self.num_patches = (image_size // 14) ** 2
        if self.use_cls_token:
            self.num_patches += 1

        self.transform = transforms.Compose(
            [
                transforms.Resize(image_size, transforms.InterpolationMode.BILINEAR, antialias=True),
                transforms.CenterCrop(image_size),
                transforms.Normalize(
                    mean=self.mean,
                    std=self.std,
                ),
            ]
        )

        # Masks use nearest-neighbour resizing so their values are not blended.
        self.mask_transform = transforms.Compose(
            [
                transforms.Resize(image_size, interpolation=transforms.InterpolationMode.NEAREST),
                transforms.CenterCrop(image_size),
            ]
        )

    def forward(self, image, mask=None, value_range=(-1, 1), **kwargs):
        """Encode images, optionally keeping only tokens selected by ``mask``.

        Args:
            image: Image batch; rescaled from ``value_range`` to ``[0, 1]``
                before the transforms are applied.
            mask: Optional mask batch. It is max-pooled over 14x14 windows and
                a token is dropped when its pooled value equals ``-1``.
            value_range: ``(low, high)`` range of ``image``, or ``None`` to
                skip the rescaling.
            **kwargs: Accepted and ignored.

        Returns:
            Without ``mask``: the backbone's ``last_hidden_state`` (minus its
            first token when ``use_cls_token`` is false). With ``mask``: a
            tensor of shape ``(batch, max_valid_tokens, hidden)`` in which each
            sample's kept tokens come first, in order, and the remaining rows
            are filled with ``-1.0``.
        """
        if value_range is not None:
            low, high = value_range
            image = (image - low) / (high - low)

        image = image.to(self.model.device, dtype=self.model.dtype)
        inputs = self.transform(image)
        outputs = self.model(inputs)

        last_hidden_state = outputs.last_hidden_state
        if not self.use_cls_token:
            last_hidden_state = last_hidden_state[:, 1:, :]

        if mask is None:
            return last_hidden_state

        pool = nn.MaxPool2d(kernel_size=(14, 14), stride=(14, 14))

        mask = self.mask_transform(mask)
        mask = mask.to(image.device, dtype=image.dtype)
        patch_mask = pool(mask)
        patch_mask = patch_mask.view(patch_mask.shape[0], -1)

        if self.use_cls_token:
            # The leading (CLS) token is always kept.
            cls_column = torch.ones(patch_mask.shape[0], 1, device=patch_mask.device, dtype=patch_mask.dtype)
            patch_mask = torch.cat([cls_column, patch_mask], dim=1)

        keep = patch_mask != -1
        max_valid_tokens = keep.sum(dim=1).max().item()

        # Create output tensor with special padding value (-1) instead of zeros
        final_output = torch.full(
            (keep.shape[0], max_valid_tokens, last_hidden_state.shape[-1]),
            -1.0,  # Use -1 as padding value to clearly distinguish from valid tokens
            device=last_hidden_state.device, dtype=last_hidden_state.dtype
        )

        # Gather the kept tokens directly with the boolean mask. This avoids
        # multiplying the whole hidden state by a float32 mask (which upcasts
        # half-precision activations) and building explicit index grids.
        for i in range(keep.shape[0]):
            tokens = last_hidden_state[i][keep[i]]
            final_output[i, :tokens.shape[0]] = tokens

        return final_output

    def unconditional_embedding(self, batch_size, **kwargs):
        """Return an all-zero embedding used as the unconditional input.

        Args:
            batch_size: Number of samples.
            **kwargs: ``num_tokens`` overrides the token count, which defaults
                to ``self.num_patches``.

        Returns:
            Zero tensor of shape ``(batch_size, num_tokens, hidden_size)`` on
            the device and dtype of the model's first parameter.
        """
        reference = next(self.model.parameters())
        num_tokens = kwargs.get('num_tokens', self.num_patches)

        return torch.zeros(
            batch_size,
            num_tokens,
            self.model.config.hidden_size,
            device=reference.device,
            dtype=reference.dtype,
        )
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
class CLIPImageEncoder(ImageEncoder):
    """``ImageEncoder`` configured with a CLIP vision model."""

    # Classes used to build the model from a config dict.
    MODEL_CLASS = CLIPVisionModelWithProjection
    MODEL_CONFIG_CLASS = CLIPVisionConfig

    # Per-channel statistics passed to ``transforms.Normalize``.
    mean = [0.48145466, 0.4578275, 0.40821073]
    std = [0.26862954, 0.26130258, 0.27577711]
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
class DinoImageEncoder(ImageEncoder):
    """``ImageEncoder`` configured with a DINOv2 model."""

    # Classes used to build the model from a config dict.
    MODEL_CLASS = Dinov2Model
    MODEL_CONFIG_CLASS = Dinov2Config

    # Per-channel statistics passed to ``transforms.Normalize``.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
|
| 197 |
-
|
| 198 |
-
class DinoImageEncoderMV(DinoImageEncoder):
    """Multi-view DINO encoder that adds a per-view sinusoidal embedding."""

    def __init__(
        self,
        version=None,
        config=None,
        use_cls_token=True,
        image_size=224,
        view_num=4,
        **kwargs,
    ):
        """Build the encoder and precompute the view embeddings.

        Args:
            version: Forwarded to the parent constructor.
            config: Forwarded to the parent constructor.
            use_cls_token: Forwarded to the parent constructor.
            image_size: Forwarded to the parent constructor.
            view_num: Number of distinct views an embedding is prepared for.
            **kwargs: Forwarded to the parent constructor.
        """
        super().__init__(version, config, use_cls_token, image_size, **kwargs)
        self.view_num = view_num
        pos = np.arange(self.view_num, dtype=np.float32)
        view_embedding = torch.from_numpy(
            get_1d_sincos_pos_embed_from_grid(self.model.config.hidden_size, pos)).float()

        # One embedding vector per view, repeated for every token of that view:
        # final shape (1, view_num, num_patches, hidden_size).
        view_embedding = view_embedding.unsqueeze(1).repeat(1, self.num_patches, 1)
        self.view_embed = view_embedding.unsqueeze(0)

    def forward(self, image, mask=None, value_range=(-1, 1), view_idxs=None):
        """Encode a batch of multi-view images into one token sequence per sample.

        Args:
            image: Tensor of shape ``(batch, num_views, C, H, W)``.
            mask: Accepted for interface compatibility; not used here.
            value_range: ``(low, high)`` range of ``image``, or ``None`` to
                skip the rescaling to ``[0, 1]``.
            view_idxs: Optional per-sample sequences of view indices selecting
                which view embedding is added to each view; must contain
                ``batch`` entries of length ``num_views``.

        Returns:
            Tensor of shape ``(batch, num_views * tokens, hidden_size)``.
        """
        if value_range is not None:
            low, high = value_range
            image = (image - low) / (high - low)

        image = image.to(self.model.device, dtype=self.model.dtype)

        # Fold the views into the batch axis for the backbone.
        bs, num_views, c, h, w = image.shape
        image = image.view(bs * num_views, c, h, w)

        inputs = self.transform(image)
        outputs = self.model(inputs)

        last_hidden_state = outputs.last_hidden_state
        # Drop the leading token when it is not wanted, as the parent class
        # does; otherwise the token count is one larger than ``view_embed``
        # (built from ``num_patches``) and the addition below cannot broadcast.
        if not self.use_cls_token:
            last_hidden_state = last_hidden_state[:, 1:, :]
        last_hidden_state = last_hidden_state.view(
            bs, num_views, last_hidden_state.shape[-2],
            last_hidden_state.shape[-1]
        )

        view_embedding = self.view_embed.to(last_hidden_state.dtype).to(last_hidden_state.device)
        if view_idxs is not None:
            assert len(view_idxs) == bs
            view_embeddings = []
            for i in range(bs):
                view_idx = view_idxs[i]
                assert num_views == len(view_idx)
                view_embeddings.append(self.view_embed[:, view_idx, ...])
            view_embedding = torch.cat(view_embeddings, 0).to(last_hidden_state.dtype).to(last_hidden_state.device)

        # With fewer input views than prepared embeddings, use the first ones.
        if num_views != self.view_num:
            view_embedding = view_embedding[:, :num_views, ...]
        last_hidden_state = last_hidden_state + view_embedding
        last_hidden_state = last_hidden_state.view(bs, num_views * last_hidden_state.shape[-2],
                                                   last_hidden_state.shape[-1])
        return last_hidden_state
|
| 253 |
-
|
| 254 |
-
def unconditional_embedding(self, batch_size, view_idxs=None, **kwargs):
|
| 255 |
-
device = next(self.model.parameters()).device
|
| 256 |
-
dtype = next(self.model.parameters()).dtype
|
| 257 |
-
zero = torch.zeros(
|
| 258 |
-
batch_size,
|
| 259 |
-
self.num_patches * len(view_idxs[0]),
|
| 260 |
-
self.model.config.hidden_size,
|
| 261 |
-
device=device,
|
| 262 |
-
dtype=dtype,
|
| 263 |
-
)
|
| 264 |
-
return zero
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
def build_image_encoder(config):
    """Instantiate the image encoder described by ``config``.

    ``config['type']`` names the encoder class and ``config['kwargs']`` holds
    the keyword arguments passed to its constructor.

    Raises:
        ValueError: if ``config['type']`` is not a known encoder name.
    """
    encoder_type = config['type']
    known_types = ('CLIPImageEncoder', 'DinoImageEncoder', 'DinoImageEncoderMV')
    if encoder_type not in known_types:
        raise ValueError(f'Unknown image encoder type: {config["type"]}')

    # Resolved only after the guard so an unknown type never touches the classes.
    encoder_classes = {
        'CLIPImageEncoder': CLIPImageEncoder,
        'DinoImageEncoder': DinoImageEncoder,
        'DinoImageEncoderMV': DinoImageEncoderMV,
    }
    return encoder_classes[encoder_type](**config['kwargs'])
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
class DualImageEncoder(nn.Module):
    """Runs two image encoders on the same input and returns both results.

    Results are returned in a dict under the keys ``'main'`` and
    ``'additional'``.
    """

    def __init__(
        self,
        main_image_encoder,
        additional_image_encoder,
    ):
        """Build both encoders from their config dicts via ``build_image_encoder``."""
        super().__init__()
        self.main_image_encoder = build_image_encoder(main_image_encoder)
        self.additional_image_encoder = build_image_encoder(additional_image_encoder)

    def forward(self, image, mask=None, **kwargs):
        """Encode ``image`` with both encoders (main first, then additional)."""
        main_features = self.main_image_encoder(image, mask=mask, **kwargs)
        additional_features = self.additional_image_encoder(image, mask=mask, **kwargs)
        return {'main': main_features, 'additional': additional_features}

    def unconditional_embedding(self, batch_size, **kwargs):
        """Collect the unconditional embedding of each wrapped encoder."""
        main_uncond = self.main_image_encoder.unconditional_embedding(batch_size, **kwargs)
        additional_uncond = self.additional_image_encoder.unconditional_embedding(batch_size, **kwargs)
        return {'main': main_uncond, 'additional': additional_uncond}
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
class SingleImageEncoder(nn.Module):
    """Wraps one image encoder and optionally drops whole samples' features.

    The encoder output is returned in a dict under the key ``'main'``.  When
    dropping is enabled, each sample's features are zeroed independently with
    probability ``drop_ratio``.
    """

    def __init__(
        self,
        main_image_encoder,
        drop_ratio=0.1,
    ):
        """Build the encoder from its config dict and store the drop ratio."""
        super().__init__()
        self.main_image_encoder = build_image_encoder(main_image_encoder)
        self.drop_ratio = drop_ratio

    def forward(self, image, disable_drop=True, mask=None, **kwargs):
        """Encode ``image`` and optionally zero out random samples.

        Args:
            image: input batch; ``len(image)`` is used as the batch size.
            disable_drop: when true (default) the features are returned as-is.
            mask: forwarded to the wrapped encoder.
            **kwargs: forwarded to the wrapped encoder.

        Returns:
            ``{'main': features}``.  With dropping enabled, the features are
            multiplied by a per-sample keep mask reshaped to ``(-1, 1, 1)``.
        """
        outputs = {
            'main': self.main_image_encoder(image, mask=mask, **kwargs),
        }
        if disable_drop:
            return outputs

        features = outputs['main']
        # Draw the mask on the features' own device instead of a hard-coded
        # 'cuda', so this also works on CPU and on non-default GPUs.
        random_p = torch.rand(len(image), device=features.device)
        keep = (random_p > self.drop_ratio).view(-1, 1, 1)
        outputs['main'] = features * keep.to(features.dtype)
        return outputs

    def unconditional_embedding(self, batch_size, **kwargs):
        """Return the wrapped encoder's unconditional embedding under ``'main'``."""
        outputs = {
            'main': self.main_image_encoder.unconditional_embedding(batch_size, **kwargs),
        }
        return outputs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ultrashape/models/denoisers/__init__.py
DELETED
|
@@ -1,22 +0,0 @@
|
|
| 1 |
-
# ==============================================================================
|
| 2 |
-
# Original work Copyright (c) 2025 Tencent.
|
| 3 |
-
# Modified work Copyright (c) 2025 UltraShape Team.
|
| 4 |
-
#
|
| 5 |
-
# Modified by UltraShape on 2025.12.25
|
| 6 |
-
# ==============================================================================
|
| 7 |
-
|
| 8 |
-
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
| 9 |
-
# except for the third-party components listed below.
|
| 10 |
-
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
| 11 |
-
# in the respective licenses of these third-party components.
|
| 12 |
-
# Users must comply with all terms and conditions of original licenses of these third-party
|
| 13 |
-
# components and must ensure that the usage of the third party components adheres to
|
| 14 |
-
# all relevant laws and regulations.
|
| 15 |
-
|
| 16 |
-
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
| 17 |
-
# their software and algorithms, including trained model weights, parameters (including
|
| 18 |
-
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
| 19 |
-
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
| 20 |
-
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
| 21 |
-
|
| 22 |
-
from .dit_mask import RefineDiT
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ultrashape/models/denoisers/dit_mask.py
DELETED
|
@@ -1,725 +0,0 @@
|
|
| 1 |
-
# ==============================================================================
|
| 2 |
-
# Original work Copyright (c) 2025 Tencent.
|
| 3 |
-
# Modified work Copyright (c) 2025 UltraShape Team.
|
| 4 |
-
#
|
| 5 |
-
# Modified by UltraShape on 2025.12.25
|
| 6 |
-
# ==============================================================================
|
| 7 |
-
|
| 8 |
-
# Open Source Model Licensed under the Apache License Version 2.0
|
| 9 |
-
# and Other Licenses of the Third-Party Components therein:
|
| 10 |
-
# The below Model in this distribution may have been modified by THL A29 Limited
|
| 11 |
-
# ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
|
| 12 |
-
|
| 13 |
-
# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
|
| 14 |
-
# The below software and/or models in this distribution may have been
|
| 15 |
-
# modified by THL A29 Limited ("Tencent Modifications").
|
| 16 |
-
# All Tencent Modifications are Copyright (C) THL A29 Limited.
|
| 17 |
-
|
| 18 |
-
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
| 19 |
-
# except for the third-party components listed below.
|
| 20 |
-
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
| 21 |
-
# in the respective licenses of these third-party components.
|
| 22 |
-
# Users must comply with all terms and conditions of original licenses of these third-party
|
| 23 |
-
# components and must ensure that the usage of the third party components adheres to
|
| 24 |
-
# all relevant laws and regulations.
|
| 25 |
-
|
| 26 |
-
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
| 27 |
-
# their software and algorithms, including trained model weights, parameters (including
|
| 28 |
-
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
| 29 |
-
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
| 30 |
-
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
| 31 |
-
|
| 32 |
-
import os
|
| 33 |
-
import yaml
|
| 34 |
-
import math
|
| 35 |
-
|
| 36 |
-
import numpy as np
|
| 37 |
-
import torch
|
| 38 |
-
import torch.nn as nn
|
| 39 |
-
import torch.nn.functional as F
|
| 40 |
-
from einops import rearrange
|
| 41 |
-
|
| 42 |
-
from .moe_layers import MoEBlock
|
| 43 |
-
from ...utils import logger, synchronize_timer, smart_load_model
|
| 44 |
-
|
| 45 |
-
from flash_attn import flash_attn_varlen_func
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
def modulate(x, shift, scale):
    """Apply an affine modulation ``x * (1 + scale) + shift``.

    ``shift`` and ``scale`` each gain a new axis at dim 1 so that they
    broadcast across dim 1 of ``x``.
    """
    gain = 1 + scale.unsqueeze(1)
    offset = shift.unsqueeze(1)
    return x * gain + offset
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
class Timesteps(nn.Module):
    """Sinusoidal embedding of a 1-D tensor of timesteps.

    The output has ``num_channels`` features per timestep: the sines of the
    scaled timestep at ``num_channels // 2`` frequencies followed by the
    matching cosines, plus one zero-padded column when ``num_channels`` is odd.
    """

    def __init__(
        self,
        num_channels: int,
        downscale_freq_shift: float = 0.0,
        scale: int = 1,
        max_period: int = 10000,
    ):
        """Store the embedding hyper-parameters; the module has no weights."""
        super().__init__()
        self.num_channels = num_channels
        self.downscale_freq_shift = downscale_freq_shift
        self.scale = scale
        self.max_period = max_period

    def forward(self, timesteps):
        """Return the ``(len(timesteps), num_channels)`` sinusoidal embedding."""
        assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
        half_dim = self.num_channels // 2

        # Geometric frequency ladder: exp(-log(max_period) * i / (half_dim - shift)).
        freq_index = torch.arange(
            start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
        exponent = -math.log(self.max_period) * freq_index
        exponent = exponent / (half_dim - self.downscale_freq_shift)
        frequencies = torch.exp(exponent)

        angles = timesteps[:, None].float() * frequencies[None, :]
        angles = self.scale * angles
        embedding = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

        if self.num_channels % 2 == 1:
            # Odd width: pad one zero column on the right of the last dim.
            embedding = torch.nn.functional.pad(embedding, (0, 1, 0, 0))
        return embedding
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.

    A sinusoidal embedding of width ``hidden_size`` is passed through a
    two-layer MLP (``hidden_size -> frequency_embedding_size -> out_size``).
    """

    def __init__(self, hidden_size, frequency_embedding_size=256, cond_proj_dim=None, out_size=None):
        """Create the MLP, the optional condition projection and the sinusoid.

        ``out_size`` defaults to ``hidden_size``.  ``cond_proj`` is only
        created when ``cond_proj_dim`` is given.
        """
        super().__init__()
        out_size = hidden_size if out_size is None else out_size
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, frequency_embedding_size, bias=True),
            nn.GELU(),
            nn.Linear(frequency_embedding_size, out_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

        if cond_proj_dim is not None:
            self.cond_proj = nn.Linear(cond_proj_dim, frequency_embedding_size, bias=False)

        self.time_embed = Timesteps(hidden_size)

    def forward(self, t, condition):
        """Embed ``t`` (optionally shifted by a projected ``condition``).

        Returns the MLP output with a new axis inserted at dim 1.
        """
        param_dtype = self.mlp[0].weight.dtype
        t_freq = self.time_embed(t).type(param_dtype)

        if condition is not None:
            # NOTE(review): cond_proj emits frequency_embedding_size features
            # while t_freq has hidden_size features, so this sum only works
            # when the two widths broadcast -- confirm intended sizes.
            t_freq = t_freq + self.cond_proj(condition)

        return self.mlp(t_freq).unsqueeze(dim=1)
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
class MLP(nn.Module):
    """Feed-forward block: ``width -> 4 * width -> width`` with a GELU between."""

    def __init__(self, *, width: int):
        """Create the two linear layers and the activation for ``width``."""
        super().__init__()
        self.width = width
        self.fc1 = nn.Linear(width, width * 4)
        self.fc2 = nn.Linear(width * 4, width)
        self.gelu = nn.GELU()

    def forward(self, x):
        """Expand, apply GELU, and project back to ``width`` features."""
        hidden = self.fc1(x)
        hidden = self.gelu(hidden)
        return self.fc2(hidden)
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
class CrossAttention(nn.Module):
    """Multi-head cross-attention from query tokens ``x`` to context ``y``.

    Context tokens whose features are all ``-1`` are treated as padding.  If
    any padding is present the variable-length flash-attention kernel is used
    on the valid tokens only; otherwise PyTorch's scaled-dot-product attention
    is used.
    """

    def __init__(
        self,
        qdim,
        kdim,
        num_heads,
        qkv_bias=True,
        qk_norm=False,
        norm_layer=nn.LayerNorm,
        **kwargs,
    ):
        """Create the projections and optional per-head q/k norms.

        Args:
            qdim: feature width of the queries (and of the output).
            kdim: feature width of the context tokens.
            num_heads: number of heads; must divide ``qdim``.
            qkv_bias: whether the q/k/v projections have a bias.
            qk_norm: apply ``norm_layer`` over ``head_dim`` to q and k.
            norm_layer: norm class used when ``qk_norm`` is set.
            **kwargs: accepted and ignored.
        """
        super().__init__()
        self.qdim = qdim
        self.kdim = kdim
        self.num_heads = num_heads
        assert self.qdim % num_heads == 0, "self.qdim must be divisible by num_heads"
        self.head_dim = self.qdim // num_heads
        assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"
        self.scale = self.head_dim ** -0.5

        self.to_q = nn.Linear(qdim, qdim, bias=qkv_bias)
        self.to_k = nn.Linear(kdim, qdim, bias=qkv_bias)
        self.to_v = nn.Linear(kdim, qdim, bias=qkv_bias)

        # TODO: eps should be 1 / 65530 if using fp16
        self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
        self.out_proj = nn.Linear(qdim, qdim, bias=True)

    def forward(self, x, y):
        """
        Parameters
        ----------
        x: torch.Tensor
            (batch, seqlen1, hidden_dim) (where hidden_dim = num heads * head dim)
        y: torch.Tensor
            (batch, seqlen2, hidden_dim2) - may contain padding (marked with -1)

        Returns
        -------
        torch.Tensor
            (batch, seqlen1, hidden_dim)
        """
        b, s1, _ = x.shape  # [b, s1, D]

        # Detect padding tokens: a token is padding when every feature is -1.
        # y_mask: [b, s2], True for valid tokens, False for padding
        y_mask = (y != -1).any(dim=-1)  # [b, s2]
        has_padding = not y_mask.all()

        _, s2, _ = y.shape
        q = self.to_q(x)
        k = self.to_k(y)
        v = self.to_v(y)

        # Fused-layout split: the concatenated [k | v] features are cut into
        # num_heads chunks and each chunk is halved.  This ordering defines
        # how the weights are interpreted and must not be changed.
        kv = torch.cat((k, v), dim=-1)
        split_size = kv.shape[-1] // self.num_heads // 2
        kv = kv.view(1, -1, self.num_heads, split_size * 2)
        k, v = torch.split(kv, split_size, dim=-1)

        q = q.view(b, s1, self.num_heads, self.head_dim)  # [b, s1, h, d]
        k = k.view(b, s2, self.num_heads, self.head_dim)  # [b, s2, h, d]
        v = v.view(b, s2, self.num_heads, self.head_dim)  # [b, s2, h, d]

        q = self.q_norm(q)
        k = self.k_norm(k)

        if has_padding:
            # Remember the working dtype so the bf16 kernel output can be
            # restored before out_proj (whose weights may be fp16/fp32).
            compute_dtype = q.dtype
            seqlens_k = y_mask.sum(dim=1).int()
            # One host sync for all lengths instead of one .item() per sample.
            valid_lens = seqlens_k.tolist()
            q_flat = q.reshape(-1, self.num_heads, self.head_dim)

            # For k, v: keep only the first valid_len tokens of each sample
            # (valid tokens are assumed to precede the padding).
            valid_indices = []
            cu_seqlens_k = [0]
            for i, valid_len in enumerate(valid_lens):
                valid_indices.append(torch.arange(valid_len, device=y.device) + i * s2)
                cu_seqlens_k.append(cu_seqlens_k[-1] + valid_len)

            valid_indices = torch.cat(valid_indices)
            k_flat = k.reshape(b * s2, self.num_heads, self.head_dim)[valid_indices]  # [total_k, h, d]
            v_flat = v.reshape(b * s2, self.num_heads, self.head_dim)[valid_indices]  # [total_k, h, d]

            # Cumulative sequence lengths expected by the varlen kernel.
            cu_seqlens_q = torch.arange(0, (b + 1) * s1, s1, dtype=torch.int32, device=x.device)
            cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32, device=x.device)

            # Call flash attention varlen (inputs are cast to bfloat16).
            q_flat = q_flat.to(torch.bfloat16)
            k_flat = k_flat.to(torch.bfloat16)
            v_flat = v_flat.to(torch.bfloat16)

            context = flash_attn_varlen_func(
                q_flat, k_flat, v_flat,
                cu_seqlens_q, cu_seqlens_k,
                s1, max(valid_lens),
                dropout_p=0.0,
                softmax_scale=None,
                causal=False
            )
            context = context.reshape(b, s1, -1).to(compute_dtype)
        else:
            with torch.backends.cuda.sdp_kernel(
                enable_flash=True,
                enable_math=False,
                enable_mem_efficient=True
            ):
                # [b, n, h, d] -> [b, h, n, d]
                q, k, v = (t.transpose(1, 2) for t in (q, k, v))
                context = F.scaled_dot_product_attention(
                    q, k, v, attn_mask=None
                ).transpose(1, 2).reshape(b, s1, -1)

        out = self.out_proj(context)

        return out
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
class Attention(nn.Module):
    """
    We rename some layer names to align with flash attention

    Multi-head self-attention with optional per-head q/k normalisation and
    optional rotary position embedding.
    """

    def __init__(
        self,
        dim,
        num_heads,
        qkv_bias=True,
        qk_norm=False,
        norm_layer=nn.LayerNorm,
    ):
        """Create the q/k/v/out projections and the optional q/k norms."""
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        assert self.dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.head_dim = self.dim // num_heads
        # This assertion is aligned with flash attention
        assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"
        self.scale = self.head_dim ** -0.5

        self.to_q = nn.Linear(dim, dim, bias=qkv_bias)
        self.to_k = nn.Linear(dim, dim, bias=qkv_bias)
        self.to_v = nn.Linear(dim, dim, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
        self.out_proj = nn.Linear(dim, dim)

    def forward(self, x, rotary_cos=None, rotary_sin=None):
        """Self-attend over ``x`` of shape ``(batch, tokens, dim)``.

        When ``rotary_cos`` is given, ``apply_rotary_emb`` is applied to q and
        k with ``rotary_cos`` / ``rotary_sin`` before attention.
        """
        batch, seq_len, _ = x.shape

        # Fused-layout split: the concatenated [q | k | v] features are cut
        # into num_heads chunks and each chunk is divided in three.  This
        # ordering defines how the weights are interpreted.
        fused = torch.cat((self.to_q(x), self.to_k(x), self.to_v(x)), dim=-1)
        per_head = fused.shape[-1] // self.num_heads // 3
        fused = fused.view(1, -1, self.num_heads, per_head * 3)
        q, k, v = (
            part.reshape(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)  # [b, h, s, d]
            for part in torch.split(fused, per_head, dim=-1)
        )

        q = self.q_norm(q)  # [b, h, s, d]
        k = self.k_norm(k)  # [b, h, s, d]

        # ========================= Apply RoPE =========================
        if rotary_cos is not None:
            q = apply_rotary_emb(q, rotary_cos, rotary_sin)
            k = apply_rotary_emb(k, rotary_cos, rotary_sin)
        # ==============================================================

        with torch.backends.cuda.sdp_kernel(
            enable_flash=True,
            enable_math=False,
            enable_mem_efficient=True
        ):
            attended = F.scaled_dot_product_attention(q, k, v)
            attended = attended.transpose(1, 2).reshape(batch, seq_len, -1)

        return self.out_proj(attended)
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
class DiTBlock(nn.Module):
    """One transformer block: optional skip fusion, self-attention,
    cross-attention and a feed-forward stage (dense MLP or MoE), each added
    residually.
    """

    def __init__(
        self,
        hidden_size,
        c_emb_size,
        num_heads,
        text_states_dim=1024,
        use_flash_attn=False,
        qk_norm=False,
        norm_layer=nn.LayerNorm,
        qk_norm_layer=nn.RMSNorm,
        init_scale=1.0,
        qkv_bias=True,
        skip_connection=True,
        timested_modulate=False,
        use_moe: bool = False,
        num_experts: int = 8,
        moe_top_k: int = 2,
        **kwargs,
    ):
        """Build the sub-layers.

        Args:
            hidden_size: token feature width.
            c_emb_size: width of the conditioning vector fed to the optional
                timestep modulation.
            num_heads: attention heads for both attention layers.
            text_states_dim: feature width of the cross-attention context.
            use_flash_attn: stored on the module; not otherwise used here.
            qk_norm / qk_norm_layer: q/k normalisation switch and class.
            norm_layer: class used for the block-level norms.
            init_scale: forwarded to ``CrossAttention``.
            qkv_bias: bias switch for attention projections.
            skip_connection: create the skip fusion layers.
            timested_modulate: create the additive timestep modulation.
            use_moe, num_experts, moe_top_k: select and size the MoE FFN.
            **kwargs: accepted and ignored.
        """
        super().__init__()
        self.use_flash_attn = use_flash_attn

        # ========================= Self-Attention =========================
        self.norm1 = norm_layer(hidden_size, elementwise_affine=True, eps=1e-6)
        self.attn1 = Attention(
            hidden_size,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            norm_layer=qk_norm_layer,
        )

        # Norm applied before cross-attention (see forward).
        self.norm2 = norm_layer(hidden_size, elementwise_affine=True, eps=1e-6)

        # ========================= Add =========================
        # Simply use add like SDXL.
        self.timested_modulate = timested_modulate
        if self.timested_modulate:
            self.default_modulation = nn.Sequential(
                nn.SiLU(),
                nn.Linear(c_emb_size, hidden_size, bias=True),
            )

        # ========================= Cross-Attention =========================
        self.attn2 = CrossAttention(
            hidden_size,
            text_states_dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            norm_layer=qk_norm_layer,
            init_scale=init_scale,
        )
        # Norm applied before the feed-forward stage (see forward).
        self.norm3 = norm_layer(hidden_size, elementwise_affine=True, eps=1e-6)

        if skip_connection:
            self.skip_norm = norm_layer(hidden_size, elementwise_affine=True, eps=1e-6)
            self.skip_linear = nn.Linear(2 * hidden_size, hidden_size)
        else:
            self.skip_linear = None

        # ========================= FFN =========================
        self.use_moe = use_moe
        if self.use_moe:
            self.moe = MoEBlock(
                hidden_size,
                num_experts=num_experts,
                moe_top_k=moe_top_k,
                dropout=0.0,
                activation_fn="gelu",
                final_dropout=False,
                ff_inner_dim=int(hidden_size * 4.0),
                ff_bias=True,
            )
        else:
            self.mlp = MLP(width=hidden_size)

    def forward(self, x, c=None, text_states=None, skip_value=None, rotary_cos=None, rotary_sin=None):
        """Run the block on ``x``.

        ``skip_value`` is concatenated with ``x`` on the last dim and fused
        when the block was built with a skip connection; ``c`` feeds the
        optional timestep modulation; ``text_states`` is the cross-attention
        context; the rotary tensors are passed to self-attention.
        """
        if self.skip_linear is not None:
            fused = torch.cat([skip_value, x], dim=-1)
            x = self.skip_norm(self.skip_linear(fused))

        # Self-Attention
        if self.timested_modulate:
            x = x + self.default_modulation(c).unsqueeze(dim=1)
        x = x + self.attn1(self.norm1(x), rotary_cos=rotary_cos, rotary_sin=rotary_sin)

        # Cross-Attention
        x = x + self.attn2(self.norm2(x), text_states)

        # FFN Layer
        ffn_input = self.norm3(x)
        ffn_output = self.moe(ffn_input) if self.use_moe else self.mlp(ffn_input)
        return x + ffn_output
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
class AttentionPool(nn.Module):
    """Pools a token sequence into one vector with a single attention query.

    A summary token (mean, or mask-weighted mean) is prepended to the
    sequence, a learned positional embedding is added, and the summary token
    attends over the whole sequence.
    """

    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        """Create the positional table (``spacial_dim + 1`` rows) and projections."""
        super().__init__()
        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim ** 0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x, attention_mask=None):
        """Pool ``x`` (batch-first, ``N x L x C``) into one vector per sample.

        ``attention_mask`` (``N x L``), when given, weights the summary-token
        average; it is not passed to the attention call itself.
        """
        tokens = x.permute(1, 0, 2)  # NLC -> LNC
        if attention_mask is None:
            summary = tokens.mean(dim=0, keepdim=True)
        else:
            weights = attention_mask.unsqueeze(-1).permute(1, 0, 2)
            global_emb = (tokens * weights).sum(dim=0) / weights.sum(dim=0)
            summary = global_emb[None,]
        tokens = torch.cat([summary, tokens], dim=0)  # (L+1)NC
        tokens = tokens + self.positional_embedding[:, None, :].to(tokens.dtype)  # (L+1)NC

        pooled, _ = F.multi_head_attention_forward(
            query=tokens[:1], key=tokens, value=tokens,
            embed_dim_to_check=tokens.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False
        )
        return pooled.squeeze(0)
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
class FinalLayer(nn.Module):
    """
    The final layer of DiT.

    Normalises the tokens, drops the first token along dim 1, and projects
    the remaining tokens to ``out_channels``.
    """

    def __init__(self, final_hidden_size, out_channels):
        """Create the LayerNorm and the output projection."""
        super().__init__()
        self.final_hidden_size = final_hidden_size
        self.norm_final = nn.LayerNorm(final_hidden_size, elementwise_affine=True, eps=1e-6)
        self.linear = nn.Linear(final_hidden_size, out_channels, bias=True)

    def forward(self, x):
        """Return the projection of every token except the first."""
        normalized = self.norm_final(x)
        # Token 0 is discarded before the projection.
        return self.linear(normalized[:, 1:])
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
class RefineDiT(nn.Module):
|
| 471 |
-
|
| 472 |
-
@classmethod
|
| 473 |
-
@synchronize_timer('Refine Model Loading')
|
| 474 |
-
def from_single_file(
|
| 475 |
-
cls,
|
| 476 |
-
ckpt_path,
|
| 477 |
-
config_path,
|
| 478 |
-
device='cuda',
|
| 479 |
-
dtype=torch.float16,
|
| 480 |
-
use_safetensors=None,
|
| 481 |
-
**kwargs,
|
| 482 |
-
):
|
| 483 |
-
# load config
|
| 484 |
-
with open(config_path, 'r') as f:
|
| 485 |
-
config = yaml.safe_load(f)
|
| 486 |
-
|
| 487 |
-
# load ckpt
|
| 488 |
-
if use_safetensors:
|
| 489 |
-
ckpt_path = ckpt_path.replace('.ckpt', '.safetensors')
|
| 490 |
-
if not os.path.exists(ckpt_path):
|
| 491 |
-
raise FileNotFoundError(f"Model file {ckpt_path} not found")
|
| 492 |
-
|
| 493 |
-
logger.info(f"Loading model from {ckpt_path}")
|
| 494 |
-
if use_safetensors:
|
| 495 |
-
import safetensors.torch
|
| 496 |
-
ckpt = safetensors.torch.load_file(ckpt_path, device='cpu')
|
| 497 |
-
else:
|
| 498 |
-
ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=True)
|
| 499 |
-
|
| 500 |
-
if 'model' in ckpt:
|
| 501 |
-
ckpt = ckpt['model']
|
| 502 |
-
if 'model' in config:
|
| 503 |
-
config = config['model']
|
| 504 |
-
|
| 505 |
-
model_kwargs = config['params']
|
| 506 |
-
model_kwargs.update(kwargs)
|
| 507 |
-
|
| 508 |
-
model = cls(**model_kwargs)
|
| 509 |
-
model.load_state_dict(ckpt)
|
| 510 |
-
model.to(device=device, dtype=dtype)
|
| 511 |
-
return model
|
| 512 |
-
|
| 513 |
-
@classmethod
|
| 514 |
-
def from_pretrained(
|
| 515 |
-
cls,
|
| 516 |
-
model_path,
|
| 517 |
-
device='cuda',
|
| 518 |
-
dtype=torch.float16,
|
| 519 |
-
use_safetensors=False,
|
| 520 |
-
variant='fp16',
|
| 521 |
-
subfolder='hunyuan3d-dit-v2-1',
|
| 522 |
-
**kwargs,
|
| 523 |
-
):
|
| 524 |
-
config_path, ckpt_path = smart_load_model(
|
| 525 |
-
model_path,
|
| 526 |
-
subfolder=subfolder,
|
| 527 |
-
use_safetensors=use_safetensors,
|
| 528 |
-
variant=variant
|
| 529 |
-
)
|
| 530 |
-
|
| 531 |
-
return cls.from_single_file(
|
| 532 |
-
ckpt_path,
|
| 533 |
-
config_path,
|
| 534 |
-
device=device,
|
| 535 |
-
dtype=dtype,
|
| 536 |
-
use_safetensors=use_safetensors,
|
| 537 |
-
**kwargs
|
| 538 |
-
)
|
| 539 |
-
|
| 540 |
-
def __init__(
|
| 541 |
-
self,
|
| 542 |
-
input_size=1024,
|
| 543 |
-
in_channels=4,
|
| 544 |
-
hidden_size=1024,
|
| 545 |
-
context_dim=1024,
|
| 546 |
-
depth=24,
|
| 547 |
-
num_heads=16,
|
| 548 |
-
mlp_ratio=4.0,
|
| 549 |
-
norm_type='layer',
|
| 550 |
-
qk_norm_type='rms',
|
| 551 |
-
qk_norm=False,
|
| 552 |
-
text_len=257,
|
| 553 |
-
guidance_cond_proj_dim=None,
|
| 554 |
-
qkv_bias=True,
|
| 555 |
-
num_moe_layers: int = 6,
|
| 556 |
-
num_experts: int = 8,
|
| 557 |
-
moe_top_k: int = 2,
|
| 558 |
-
voxel_query_res: int = 128,
|
| 559 |
-
**kwargs
|
| 560 |
-
):
|
| 561 |
-
super().__init__()
|
| 562 |
-
self.input_size = input_size
|
| 563 |
-
self.depth = depth
|
| 564 |
-
self.in_channels = in_channels
|
| 565 |
-
self.out_channels = in_channels
|
| 566 |
-
self.num_heads = num_heads
|
| 567 |
-
|
| 568 |
-
self.hidden_size = hidden_size
|
| 569 |
-
self.norm = nn.LayerNorm if norm_type == 'layer' else nn.RMSNorm
|
| 570 |
-
self.qk_norm = nn.RMSNorm if qk_norm_type == 'rms' else nn.LayerNorm
|
| 571 |
-
self.context_dim = context_dim
|
| 572 |
-
self.voxel_query_res = voxel_query_res
|
| 573 |
-
|
| 574 |
-
self.guidance_cond_proj_dim = guidance_cond_proj_dim
|
| 575 |
-
|
| 576 |
-
self.text_len = text_len
|
| 577 |
-
|
| 578 |
-
self.x_embedder = nn.Linear(in_channels, hidden_size, bias=True)
|
| 579 |
-
self.t_embedder = TimestepEmbedder(hidden_size, hidden_size * 4, cond_proj_dim=guidance_cond_proj_dim)
|
| 580 |
-
|
| 581 |
-
self.blocks = nn.ModuleList([
|
| 582 |
-
DiTBlock(hidden_size=hidden_size,
|
| 583 |
-
c_emb_size=hidden_size,
|
| 584 |
-
num_heads=num_heads,
|
| 585 |
-
mlp_ratio=mlp_ratio,
|
| 586 |
-
text_states_dim=context_dim,
|
| 587 |
-
qk_norm=qk_norm,
|
| 588 |
-
norm_layer=self.norm,
|
| 589 |
-
qk_norm_layer=self.qk_norm,
|
| 590 |
-
skip_connection=layer > depth // 2,
|
| 591 |
-
qkv_bias=qkv_bias,
|
| 592 |
-
use_moe=True if depth - layer <= num_moe_layers else False,
|
| 593 |
-
num_experts=num_experts,
|
| 594 |
-
moe_top_k=moe_top_k
|
| 595 |
-
)
|
| 596 |
-
for layer in range(depth)
|
| 597 |
-
])
|
| 598 |
-
self.depth = depth
|
| 599 |
-
|
| 600 |
-
self.final_layer = FinalLayer(hidden_size, self.out_channels)
|
| 601 |
-
|
| 602 |
-
def forward(self, x, t, contexts, **kwargs):
    """Run the DiT stack on latent tokens and return the final-layer output.

    Args:
        x: latent tokens; projected by ``self.x_embedder``
            (presumably ``[B, N, in_channels]`` -- confirm against caller).
        t: timesteps, embedded by ``self.t_embedder``.
        contexts: mapping whose ``'main'`` entry is handed to every block
            as the conditioning states.
        **kwargs: ``guidance_cond`` (optional, forwarded to the timestep
            embedder) and ``voxel_cond`` (voxel indices used to build the
            3D rotary tables; assumed ``[B, N, 3]`` -- TODO confirm).

    Returns:
        Output of ``self.final_layer`` applied to the block-stack output.
    """
    main_cond = contexts['main']

    time_tokens = self.t_embedder(t, condition=kwargs.get('guidance_cond'))
    tokens = self.x_embedder(x)

    # Rotary tables: the prepended timestep token(s) get the identity
    # rotation (cos=1, sin=0); voxel tokens get position-dependent angles.
    head_dim = self.blocks[0].attn1.head_dim
    if time_tokens.dim() == 3:
        prefix_len = time_tokens.shape[1]
    else:
        prefix_len = 1
    prefix_shape = (tokens.shape[0], prefix_len, head_dim)
    identity_cos = torch.ones(*prefix_shape, device=tokens.device)
    identity_sin = torch.zeros(*prefix_shape, device=tokens.device)

    voxel_cos, voxel_sin = precompute_freqs_cis_3d_interpolated(
        head_dim, kwargs.get('voxel_cond'), current_res=self.voxel_query_res)

    rotary_cos = torch.cat([identity_cos, voxel_cos], dim=1)
    rotary_sin = torch.cat([identity_sin, voxel_sin], dim=1)

    # Timestep token(s) are prepended to the sequence.
    tokens = torch.cat([time_tokens, tokens], dim=1)

    # U-shaped skips: blocks before the midpoint push their output, blocks
    # after the midpoint pop the most recent one (last in, first out).
    midpoint = self.depth // 2
    pending_skips = []
    for index, block in enumerate(self.blocks):
        if index > midpoint:
            skip = pending_skips.pop()
        else:
            skip = None
        tokens = block(
            tokens,
            time_tokens,
            main_cond,
            rotary_cos=rotary_cos,
            rotary_sin=rotary_sin,
            skip_value=skip,
        )
        if index < midpoint:
            pending_skips.append(tokens)

    return self.final_layer(tokens)
|
| 637 |
-
|
| 638 |
-
|
| 639 |
-
def apply_rotary_emb(x, cos, sin):
    """Apply a rotary position embedding to ``x``.

    x: [B, H, N, D]
    cos, sin: [B, N, D]

    Returns ``x * cos + rotate_half(x) * sin`` where ``rotate_half`` maps the
    two halves ``(a, b)`` of the last dimension to ``(-b, a)``.
    """
    # Insert a head axis so the tables broadcast over H.
    cos_b = cos.unsqueeze(1)
    sin_b = sin.unsqueeze(1)

    first_half, second_half = x.chunk(2, dim=-1)
    rotated = torch.cat((-second_half, first_half), dim=-1)

    return (x * cos_b) + (rotated * sin_b)
|
| 653 |
-
|
| 654 |
-
|
| 655 |
-
def precompute_freqs_cis_3d(dim: int, grid_indices: torch.Tensor, theta: float = 10000.0):
    """Build 3D rotary cos/sin tables from voxel indices.

    The rotary dimension is split across the x, y and z axes
    (``dim // 3``, ``dim // 3`` and the remainder). Each axis gets its own
    frequency band; the per-axis angles are concatenated and then duplicated
    so the table lines up with a "rotate half" rotary layout.

    Args:
        dim: rotary (head) dimension.
        grid_indices: [B, N, 3] voxel idx
        theta: base of the geometric frequency progression.

    Returns:
        ``(cos, sin)``, each of shape ``[B, N, 2 * (dim // 2)]`` (i.e.
        ``[B, N, dim]`` for even ``dim``).
    """
    dim_x = dim // 3
    dim_y = dim // 3
    dim_z = dim - dim_x - dim_y

    device = grid_indices.device

    def axis_freqs(axis_dim):
        return 1.0 / (theta ** (torch.arange(0, axis_dim, 2, device=device).float() / axis_dim))

    freqs_x = axis_freqs(dim_x)
    freqs_y = axis_freqs(dim_y)
    freqs_z = axis_freqs(dim_z)

    # Each axis yields ceil(axis_dim / 2) frequencies, so when dim // 3 is odd
    # (e.g. dim=64 -> 11 + 11 + 11 = 33) the concatenation overshoots dim // 2
    # and the returned tables would be wider than `dim`. Trim the z band so the
    # total is exactly dim // 2, as precompute_freqs_cis_3d_interpolated does.
    z_budget = max(dim // 2 - len(freqs_x) - len(freqs_y), 0)
    freqs_z = freqs_z[:z_budget]

    x_idx = grid_indices[..., 0].float()
    y_idx = grid_indices[..., 1].float()
    z_idx = grid_indices[..., 2].float()

    # angle = position * frequency, broadcast to [B, N, n_freqs]
    args_x = x_idx.unsqueeze(-1) * freqs_x.unsqueeze(0).unsqueeze(0)
    args_y = y_idx.unsqueeze(-1) * freqs_y.unsqueeze(0).unsqueeze(0)
    args_z = z_idx.unsqueeze(-1) * freqs_z.unsqueeze(0).unsqueeze(0)

    args = torch.cat([args_x, args_y, args_z], dim=-1)
    # Duplicate so both halves of the head dimension share the same angles.
    args = torch.cat([args, args], dim=-1)

    return torch.cos(args), torch.sin(args)
|
| 680 |
-
|
| 681 |
-
|
| 682 |
-
def precompute_freqs_cis_3d_interpolated(
    dim: int,
    grid_indices: torch.Tensor,
    theta: float = 10000.0,
    trained_res: float = 128.0,  # training resolution
    current_res: float = 256.0,  # inference resolution
):
    """Build 3D rotary cos/sin tables with positions rescaled to the training grid.

    Voxel coordinates are divided by ``current_res / trained_res`` before being
    multiplied by the per-axis frequency bands. The bands are sized so that the
    concatenated angles have exactly ``dim // 2`` entries; the angles are then
    duplicated along the last axis.

    Args:
        dim: rotary (head) dimension.
        grid_indices: voxel indices; the last axis is read as (x, y, z).
        theta: base of the geometric frequency progression.
        trained_res: grid resolution used during training.
        current_res: grid resolution of ``grid_indices``.

    Returns:
        ``(cos, sin)`` with last dimension ``2 * (dim // 2)``.
    """
    scale_factor = current_res / trained_res

    dim_x = dim // 3
    dim_y = dim // 3
    dim_z = dim - dim_x - dim_y

    device = grid_indices.device

    raw_bands = [
        1.0 / (theta ** (torch.arange(0, axis_dim, 2, device=device).float() / axis_dim))
        for axis_dim in (dim_x, dim_y, dim_z)
    ]

    # x and y keep ceil(axis_dim / 2) frequencies; z takes what is left of dim // 2.
    band_x = raw_bands[0][:dim_x // 2 + (dim_x % 2)]
    band_y = raw_bands[1][:dim_y // 2 + (dim_y % 2)]
    band_z = raw_bands[2][:(dim // 2 - len(band_x) - len(band_y))]

    per_axis_angles = []
    for axis, band in enumerate((band_x, band_y, band_z)):
        # Apply Scaling
        position = grid_indices[..., axis].float() / scale_factor
        # pos * freq
        per_axis_angles.append(position.unsqueeze(-1) * band.unsqueeze(0).unsqueeze(0))

    half_angles = torch.cat(per_axis_angles, dim=-1)
    angles = torch.cat([half_angles, half_angles], dim=-1)

    return torch.cos(angles), torch.sin(angles)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ultrashape/models/denoisers/moe_layers.py
DELETED
|
@@ -1,177 +0,0 @@
|
|
| 1 |
-
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
| 2 |
-
# except for the third-party components listed below.
|
| 3 |
-
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
| 4 |
-
# in the respective licenses of these third-party components.
|
| 5 |
-
# Users must comply with all terms and conditions of original licenses of these third-party
|
| 6 |
-
# components and must ensure that the usage of the third party components adheres to
|
| 7 |
-
# all relevant laws and regulations.
|
| 8 |
-
|
| 9 |
-
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
| 10 |
-
# their software and algorithms, including trained model weights, parameters (including
|
| 11 |
-
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
| 12 |
-
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
| 13 |
-
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
| 14 |
-
|
| 15 |
-
import torch
|
| 16 |
-
import torch.nn as nn
|
| 17 |
-
import numpy as np
|
| 18 |
-
import math
|
| 19 |
-
from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
|
| 20 |
-
|
| 21 |
-
import torch.nn.functional as F
|
| 22 |
-
from diffusers.models.attention import FeedForward
|
| 23 |
-
|
| 24 |
-
class AddAuxiliaryLoss(torch.autograd.Function):
    """
    Identity on ``x`` that attaches an auxiliary loss to the autograd graph.

    The forward pass returns ``x`` unchanged. The backward pass passes the
    incoming gradient through to ``x`` and, if the auxiliary loss required
    grad, feeds it a gradient of one so it is optimised together with the
    main objective.
    """

    @staticmethod
    def forward(ctx, x, loss):
        # The auxiliary loss must be a single-element tensor.
        assert loss.numel() == 1
        ctx.required_aux_loss = loss.requires_grad
        ctx.dtype = loss.dtype
        return x

    @staticmethod
    def backward(ctx, grad_output):
        if not ctx.required_aux_loss:
            return grad_output, None
        unit_grad = torch.ones(1, dtype=ctx.dtype, device=grad_output.device)
        return grad_output, unit_grad
|
| 42 |
-
|
| 43 |
-
class MoEGate(nn.Module):
    """Softmax top-k router for a mixture-of-experts layer.

    Scores every token against ``num_experts`` learned gating vectors, keeps
    the ``num_experts_per_tok`` best experts per token, and (in training, when
    ``aux_loss_alpha > 0``) computes a load-balancing auxiliary loss.

    Args:
        embed_dim: size of the token features fed to the gate.
        num_experts: number of routed experts.
        num_experts_per_tok: how many experts each token is sent to.
        aux_loss_alpha: weight of the auxiliary loss; ``<= 0`` disables it.
    """

    def __init__(self, embed_dim, num_experts=16, num_experts_per_tok=2, aux_loss_alpha=0.01):
        super().__init__()
        self.top_k = num_experts_per_tok
        self.n_routed_experts = num_experts

        self.scoring_func = 'softmax'
        self.alpha = aux_loss_alpha
        # When True the auxiliary loss is computed per sequence, else over the
        # whole flattened batch.
        self.seq_aux = False

        # topk selection algorithm
        self.norm_topk_prob = False
        self.gating_dim = embed_dim
        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """(Re-)initialise the gating weights with Kaiming-uniform values."""
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, hidden_states):
        """Route tokens to experts.

        Args:
            hidden_states: tensor of shape ``[bsz, seq_len, h]``.

        Returns:
            ``(topk_idx, topk_weight, aux_loss)`` where the first two have
            shape ``[bsz * seq_len, top_k]`` and ``aux_loss`` is a scalar
            tensor, or ``None`` outside training / when ``alpha <= 0``.

        Raises:
            NotImplementedError: if ``scoring_func`` is not ``'softmax'``.
        """
        bsz, seq_len, h = hidden_states.shape
        ### compute gating score
        # reshape (rather than view) so non-contiguous inputs, e.g. transposed
        # tensors, are accepted instead of raising a RuntimeError.
        hidden_states = hidden_states.reshape(-1, h)
        logits = F.linear(hidden_states, self.weight, None)
        if self.scoring_func == 'softmax':
            scores = logits.softmax(dim=-1)
        else:
            raise NotImplementedError(f'insupportable scoring function for MoE gating: {self.scoring_func}')

        ### select top-k experts
        topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)

        ### norm gate to sum 1
        if self.top_k > 1 and self.norm_topk_prob:
            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
            topk_weight = topk_weight / denominator

        ### expert-level computation auxiliary loss
        if self.training and self.alpha > 0.0:
            scores_for_aux = scores
            aux_topk = self.top_k
            # always compute aux loss based on the naive greedy topk method
            topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
            if self.seq_aux:
                scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
                # Per-sequence expert usage counts, normalised so that a
                # perfectly uniform assignment gives 1 for every expert.
                ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
                ce.scatter_add_(
                    1,
                    topk_idx_for_aux_loss,
                    torch.ones(
                        bsz, seq_len * aux_topk,
                        device=hidden_states.device
                    )
                ).div_(seq_len * aux_topk / self.n_routed_experts)
                aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean()
                aux_loss = aux_loss * self.alpha
            else:
                # Fraction of routing slots given to each expert (fi, scaled by
                # the number of experts) times its mean gate probability (Pi).
                mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1),
                                    num_classes=self.n_routed_experts)
                ce = mask_ce.float().mean(0)
                Pi = scores_for_aux.mean(0)
                fi = ce * self.n_routed_experts
                aux_loss = (Pi * fi).sum() * self.alpha
        else:
            aux_loss = None
        return topk_idx, topk_weight, aux_loss
|
| 111 |
-
|
| 112 |
-
class MoEBlock(nn.Module):
    """Mixture-of-experts feed-forward block with an always-on shared expert.

    Each token is routed by ``MoEGate`` to ``moe_top_k`` of ``num_experts``
    ``FeedForward`` experts; their outputs are combined using the gate weights
    and added to the output of a shared ``FeedForward`` applied to every token.

    Args:
        dim: token feature size.
        num_experts: number of routed experts.
        moe_top_k: experts selected per token.
        activation_fn, dropout, final_dropout, ff_inner_dim, ff_bias:
            forwarded to every ``FeedForward`` (routed and shared).
    """

    def __init__(self, dim, num_experts=8, moe_top_k=2,
                 activation_fn="gelu", dropout=0.0, final_dropout=False,
                 ff_inner_dim=None, ff_bias=True):
        super().__init__()
        self.moe_top_k = moe_top_k
        self.experts = nn.ModuleList([
            FeedForward(dim, dropout=dropout,
                        activation_fn=activation_fn,
                        final_dropout=final_dropout,
                        inner_dim=ff_inner_dim,
                        bias=ff_bias)
            for i in range(num_experts)])
        self.gate = MoEGate(embed_dim=dim, num_experts=num_experts, num_experts_per_tok=moe_top_k)

        self.shared_experts = FeedForward(dim, dropout=dropout, activation_fn=activation_fn,
                                          final_dropout=final_dropout, inner_dim=ff_inner_dim,
                                          bias=ff_bias)

    def initialize_weight(self):
        """No-op placeholder; sub-modules keep their default initialisation."""
        pass

    def forward(self, hidden_states):
        """Apply routed experts plus the shared expert.

        Args:
            hidden_states: tensor whose last dimension is the feature size
                (the gate unpacks it as ``[bsz, seq_len, h]``).

        Returns:
            Tensor with the same shape as ``hidden_states``.
        """
        identity = hidden_states
        orig_shape = hidden_states.shape
        topk_idx, topk_weight, aux_loss = self.gate(hidden_states)

        # reshape (rather than view) tolerates non-contiguous inputs.
        hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1])
        flat_topk_idx = topk_idx.view(-1)
        if self.training:
            # One row per (token, routing slot) so that row r lines up with
            # flat_topk_idx[r].
            hidden_states = hidden_states.repeat_interleave(self.moe_top_k, dim=0)
            y = torch.empty_like(hidden_states)
            for i, expert in enumerate(self.experts):
                routed = flat_topk_idx == i
                y[routed] = expert(hidden_states[routed]).to(hidden_states.dtype)
            y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
            y = y.view(*orig_shape)
            # The gate returns None when its auxiliary loss is disabled
            # (alpha <= 0); AddAuxiliaryLoss cannot take None.
            if aux_loss is not None:
                y = AddAuxiliaryLoss.apply(y, aux_loss)
        else:
            y = self.moe_infer(hidden_states, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
        y = y + self.shared_experts(identity)
        return y

    @torch.no_grad()
    def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
        """Gradient-free expert dispatch used outside training.

        Args:
            x: flattened tokens, ``[num_tokens, dim]``.
            flat_expert_indices: expert id for every (token, slot) pair,
                length ``num_tokens * moe_top_k``.
            flat_expert_weights: matching gate weights, shape ``[-1, 1]``.

        Returns:
            Tensor shaped like ``x`` holding the weighted sum of the selected
            experts' outputs for each token.
        """
        expert_cache = torch.zeros_like(x)
        # Sort (token, slot) pairs by expert so each expert's work is contiguous.
        idxs = flat_expert_indices.argsort()
        # Cumulative counts give the end offset of every expert's segment.
        tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
        # Map a flattened (token, slot) position back to its token row.
        token_idxs = idxs // self.moe_top_k
        for i, end_idx in enumerate(tokens_per_expert):
            start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
            if start_idx == end_idx:
                # No token was routed to this expert.
                continue
            expert = self.experts[i]
            exp_token_idx = token_idxs[start_idx:end_idx]
            expert_tokens = x[exp_token_idx]
            expert_out = expert(expert_tokens)
            expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])

            # for fp16 and other dtype
            expert_cache = expert_cache.to(expert_out.dtype)
            # Accumulate, since a token receives contributions from top_k experts.
            expert_cache.scatter_reduce_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]),
                                         expert_out,
                                         reduce='sum')
        return expert_cache
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ultrashape/models/diffusion/flow_matching_dit_trainer.py
DELETED
|
@@ -1,313 +0,0 @@
|
|
| 1 |
-
|
| 2 |
-
import os
|
| 3 |
-
from contextlib import contextmanager
|
| 4 |
-
from typing import List, Tuple, Optional, Union
|
| 5 |
-
|
| 6 |
-
import torch
|
| 7 |
-
import torch.nn as nn
|
| 8 |
-
from torch.optim import lr_scheduler
|
| 9 |
-
import pytorch_lightning as pl
|
| 10 |
-
from pytorch_lightning.utilities import rank_zero_info
|
| 11 |
-
from pytorch_lightning.utilities import rank_zero_only
|
| 12 |
-
from ultrashape.pipelines import export_to_trimesh
|
| 13 |
-
|
| 14 |
-
from ...utils.ema import LitEma
|
| 15 |
-
from ...utils.misc import instantiate_from_config, instantiate_non_trainable_model, instantiate_vae_model, instantiate_vae_model_local
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
class Diffuser(pl.LightningModule):
    """Lightning module that trains a DiT denoiser with a flow-matching loss.

    Latents come from ``first_stage_model.encode`` (run under ``no_grad``),
    conditioning comes from ``cond_stage_model``, and the loss is produced by
    ``self.transport.training_losses``. An inference pipeline is also built so
    ``sample`` can generate outputs during training.
    """

    def __init__(
        self,
        *,
        vae_config,
        cond_config,
        dit_cfg,
        scheduler_cfg,
        optimizer_cfg,
        pipeline_cfg=None,
        image_processor_cfg=None,
        lora_config=None,
        ema_config=None,
        scale_by_std: bool = False,
        z_scale_factor: float = 1.0,
        ckpt_path: Optional[str] = None,
        ignore_keys: Union[Tuple[str], List[str]] = (),
        torch_compile: bool = False,
    ):
        """Instantiate denoiser, conditioner, VAE, optional LoRA/EMA and pipeline.

        Args:
            vae_config: config passed to ``instantiate_vae_model_local``.
            cond_config: config of the conditioning model.
            dit_cfg: config of the denoiser.
            scheduler_cfg: may contain ``transport`` and ``sampler`` entries.
            optimizer_cfg: optimizer (and optional LR scheduler) config.
            pipeline_cfg: config of the inference pipeline.
            image_processor_cfg: optional image-processor config.
            lora_config: optional; reads ``rank`` and ``target_modules``.
            ema_config: optional; reads ``ema_model`` and ``ema_decay``.
            scale_by_std: if True ``z_scale_factor`` is stored as a buffer.
            z_scale_factor: multiplier applied to encoded latents.
            ckpt_path: optional checkpoint to restore from.
            ignore_keys: substrings; matching checkpoint keys are dropped.
            torch_compile: compile the three sub-models when True.
        """
        super().__init__()

        # ========= init optimizer config ========= #
        self.optimizer_cfg = optimizer_cfg

        # ========= init diffusion scheduler ========= #
        self.scheduler_cfg = scheduler_cfg
        self.sampler = None
        if 'transport' in scheduler_cfg:
            self.transport = instantiate_from_config(scheduler_cfg.transport)
            self.sampler = instantiate_from_config(scheduler_cfg.sampler, transport=self.transport)
            self.sample_fn = self.sampler.sample_ode(**scheduler_cfg.sampler.ode_params)

        # ========= init the model ========= #
        self.dit_cfg = dit_cfg
        self.model = instantiate_from_config(dit_cfg, device=None, dtype=None)

        self.cond_stage_model = instantiate_from_config(cond_config)

        self.ckpt_path = ckpt_path
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

        # ========= config lora model ========= #
        if lora_config is not None:
            from peft import LoraConfig, get_peft_model
            loraconfig = LoraConfig(
                r=lora_config.rank,
                lora_alpha=lora_config.rank,
                target_modules=lora_config.get('target_modules')
            )
            self.model = get_peft_model(self.model, loraconfig)

        # ========= config ema model ========= #
        self.ema_config = ema_config
        if self.ema_config is not None:
            if self.ema_config.ema_model == 'DSEma':
                # from michelangelo.models.modules.ema_deepspeed import DSEma
                # NOTE(review): this relative path resolves two levels up from
                # this module, unlike `...utils.ema` used for LitEma -- confirm
                # that the target module exists before enabling DSEma.
                from ..utils.ema_deepspeed import DSEma
                self.model_ema = DSEma(self.model, decay=self.ema_config.ema_decay)
            else:
                self.model_ema = LitEma(self.model, decay=self.ema_config.ema_decay)
            # NOTE(review): the original remark here said EMA weights should
            # not be initialized from the checkpoint (because the MoE layers
            # change), yet the checkpoint is loaded a second time now that the
            # EMA module exists -- confirm which behavior is intended.
            if ckpt_path is not None:
                self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

        # ========= init vae at last to prevent it is overridden by loaded ckpt ========= #
        self.first_stage_model = instantiate_vae_model_local(vae_config)
        self.first_stage_model.enable_flashvdm_decoder()

        self.scale_by_std = scale_by_std
        if scale_by_std:
            self.register_buffer("z_scale_factor", torch.tensor(z_scale_factor))
        else:
            self.z_scale_factor = z_scale_factor

        # ========= init pipeline for inference ========= #
        self.image_processor_cfg = image_processor_cfg
        self.image_processor = None
        if self.image_processor_cfg is not None:
            self.image_processor = instantiate_from_config(self.image_processor_cfg)
        self.pipeline_cfg = pipeline_cfg
        from ...schedulers import FlowMatchEulerDiscreteScheduler
        scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
        self.pipeline = instantiate_from_config(
            pipeline_cfg,
            vae=self.first_stage_model,
            model=self.model,
            scheduler=scheduler,
            conditioner=self.cond_stage_model,
            image_processor=self.image_processor,
        )

        # ========= torch compile to accelerate ========= #
        self.torch_compile = torch_compile
        if self.torch_compile:
            torch.nn.Module.compile(self.model)
            torch.nn.Module.compile(self.first_stage_model)
            torch.nn.Module.compile(self.cond_stage_model)
            print('*' * 100)
            print('Compile model for acceleration')
            print('*' * 100)

    @contextmanager
    def ema_scope(self, context=None):
        """Temporarily swap EMA weights into ``self.model``.

        Only active when ``ema_config`` is set and its ``ema_inference`` flag
        is truthy; otherwise the body runs with the current weights. The
        training weights are restored on exit, even if the body raises.
        """
        use_ema = self.ema_config is not None and self.ema_config.get('ema_inference', False)
        if use_ema:
            self.model_ema.store(self.model)
            self.model_ema.copy_to(self.model)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.ema_config is not None and self.ema_config.get('ema_inference', False):
                self.model_ema.restore(self.model)
                if context is not None:
                    print(f"{context}: Restored training weights")

    def init_from_ckpt(self, path, ignore_keys=()):
        """Load weights from ``path`` non-strictly, skipping ignored keys.

        Args:
            path: checkpoint file. A file without a ``state_dict`` entry is
                treated as a flat (DeepSpeed-style) state dict whose keys have
                ``_forward_module.`` stripped.
            ignore_keys: substrings; any key containing one is dropped.
        """
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        ckpt = torch.load(path, map_location="cpu")
        if 'state_dict' not in ckpt:
            # deepspeed ckpt
            state_dict = {}
            for k in ckpt.keys():
                new_k = k.replace('_forward_module.', '')
                state_dict[new_k] = ckpt[k]
        else:
            state_dict = ckpt["state_dict"]

        for k in list(state_dict.keys()):
            # Delete at most once per key: the previous per-pattern delete
            # raised KeyError when two ignore patterns matched the same key.
            if any(ik in k for ik in ignore_keys):
                print("Deleting key {} from state_dict.".format(k))
                del state_dict[k]

        missing, unexpected = self.load_state_dict(state_dict, strict=False)
        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
            print(f"Unexpected Keys: {unexpected}")

    def on_load_checkpoint(self, checkpoint):
        """Backfill EMA entries that are absent from a checkpoint.

        Every ``model_ema*`` key of this module's state dict that is missing
        from ``checkpoint["state_dict"]`` is copied in from the current module,
        so a checkpoint saved without EMA weights still matches the module's
        keys when Lightning loads it.
        """
        # Build the state dict once instead of once per key.
        own_state = self.state_dict()
        ckpt_state = checkpoint["state_dict"]
        for key, value in own_state.items():
            if key.startswith("model_ema") and key not in ckpt_state:
                ckpt_state[key] = value

    def configure_optimizers(self) -> Tuple[List, List]:
        """Build the optimizer and optional per-step LR scheduler.

        Reads ``self.learning_rate``, which is not assigned in this class
        (presumably set by the training script -- confirm).

        Returns:
            ``([optimizer], schedulers)``; ``schedulers`` is empty when
            ``optimizer_cfg`` has no ``scheduler`` attribute.
        """
        lr = self.learning_rate

        params_list = []
        trainable_parameters = list(self.model.parameters())
        params_list.append({'params': trainable_parameters, 'lr': lr})

        # Name fragments of parameters that are excluded from weight decay.
        no_decay = ['bias', 'norm.weight', 'norm.bias', 'norm1.weight', 'norm1.bias', 'norm2.weight', 'norm2.bias']

        if self.optimizer_cfg.get('train_image_encoder', False):
            image_encoder_parameters = list(self.cond_stage_model.named_parameters())
            image_encoder_parameters_decay = [param for name, param in image_encoder_parameters if
                                              not any((no_decay_name in name) for no_decay_name in no_decay)]
            image_encoder_parameters_nodecay = [param for name, param in image_encoder_parameters if
                                                any((no_decay_name in name) for no_decay_name in no_decay)]
            # filter trainable params
            image_encoder_parameters_decay = [param for param in image_encoder_parameters_decay if
                                              param.requires_grad]
            image_encoder_parameters_nodecay = [param for param in image_encoder_parameters_nodecay if
                                                param.requires_grad]

            print(f"Image Encoder Params: {len(image_encoder_parameters_decay)} decay, ")
            print(f"Image Encoder Params: {len(image_encoder_parameters_nodecay)} nodecay, ")

            # An explicit image_encoder_lr wins; otherwise scale the base lr.
            image_encoder_lr = self.optimizer_cfg['image_encoder_lr']
            image_encoder_lr_multiply = self.optimizer_cfg.get('image_encoder_lr_multiply', 1.0)
            image_encoder_lr = image_encoder_lr if image_encoder_lr is not None else lr * image_encoder_lr_multiply
            params_list.append(
                {'params': image_encoder_parameters_decay, 'lr': image_encoder_lr,
                 'weight_decay': 0.05})
            params_list.append(
                {'params': image_encoder_parameters_nodecay, 'lr': image_encoder_lr,
                 'weight_decay': 0.})

        optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=params_list, lr=lr)
        if hasattr(self.optimizer_cfg, 'scheduler'):
            scheduler_func = instantiate_from_config(
                self.optimizer_cfg.scheduler,
                max_decay_steps=self.trainer.max_steps,
                lr_max=lr
            )
            scheduler = {
                "scheduler": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule),
                "interval": "step",
                "frequency": 1
            }
            schedulers = [scheduler]
        else:
            schedulers = []
        optimizers = [optimizer]

        return optimizers, schedulers

    def on_train_batch_end(self, *args, **kwargs):
        """Update the EMA copy of the model after every training batch."""
        if self.ema_config is not None:
            self.model_ema(self.model)

    def on_train_epoch_start(self) -> None:
        """Re-seed all RNGs with this process's global rank at each epoch start."""
        pl.seed_everything(self.trainer.global_rank)

    def forward(self, batch, disable_drop):
        """Compute the mean flow-matching loss for one batch.

        Args:
            batch: mapping; ``image``, ``text`` and ``mask`` are read with
                ``.get`` and ``surface`` is required.
            disable_drop: forwarded to the conditioning model.

        Returns:
            Scalar loss tensor.
        """
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):  # float32 for text
            contexts = self.cond_stage_model(image=batch.get('image'), text=batch.get('text'), mask=batch.get('mask'), disable_drop=disable_drop)

        # The VAE encoder runs in fp16 and without gradients.
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            with torch.no_grad():
                latents, voxel_idx = self.first_stage_model.encode(batch["surface"], sample_posterior=True, need_voxel=True)
                latents = self.z_scale_factor * latents

        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            loss = self.transport.training_losses(self.model, latents,
                                                  dict(contexts=contexts, voxel_cond=voxel_idx))["loss"].mean()

        return loss

    def training_step(self, batch, batch_idx, optimizer_idx=0):
        """Run one training step and log the loss and current learning rate."""
        loss = self.forward(batch, disable_drop=False)
        split = 'train'
        loss_dict = {
            f"{split}/total_loss": loss.detach(),
            f"{split}/lr_abs": self.optimizers().param_groups[0]['lr'],
        }
        self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True)

        return loss

    def validation_step(self, batch, batch_idx, optimizer_idx=0):
        """Run one validation step (condition drop disabled) and log the loss."""
        loss = self.forward(batch, disable_drop=True)
        split = 'val'
        loss_dict = {
            f"{split}/total_loss": loss.detach(),
            f"{split}/lr_abs": self.optimizers().param_groups[0]['lr'],
        }
        self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True)

        return loss

    @torch.no_grad()
    def sample(self, batch, output_type='trimesh', **kwargs):
        """Generate outputs for ``batch`` with the inference pipeline.

        Uses a generator seeded with 0 and, if enabled, EMA weights. Any
        ``Exception`` raised by the pipeline is printed, appended to
        ``error.txt`` and replaced by ``[None]`` (best effort).

        Returns:
            A one-element list wrapping the pipeline outputs (or ``[None]``).
        """
        self.cond_stage_model.disable_drop = True

        generator = torch.Generator().manual_seed(0)

        try:
            with self.ema_scope("Sample"):
                with torch.amp.autocast(device_type='cuda'):
                    try:
                        self.pipeline.device = self.device
                        self.pipeline.dtype = self.dtype
                        print("### USING PIPELINE ###")
                        print(f'device: {self.device} dtype : {self.dtype}')
                        additional_params = {'output_type': output_type}

                        image = batch.get("image", None)
                        mask = batch.get('mask', None)

                        outputs = self.pipeline(image=image,
                                                mask=mask,
                                                generator=generator,
                                                box_v=1.0,
                                                mc_level=0.0,
                                                octree_resolution=1024,
                                                **additional_params)

                    except Exception as e:
                        import traceback
                        traceback.print_exc()
                        print(f"Unexpected {e=}, {type(e)=}")
                        with open("error.txt", "a") as f:
                            f.write(str(e))
                            f.write(traceback.format_exc())
                            f.write("\n")
                        outputs = [None]
        finally:
            # Always re-enable condition dropping, even if something outside
            # the inner handler (e.g. the EMA swap) fails.
            self.cond_stage_model.disable_drop = False

        return [outputs]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ultrashape/models/diffusion/transport/__init__.py
DELETED
|
@@ -1,97 +0,0 @@
|
|
| 1 |
-
# This file includes code derived from the SiT project (https://github.com/willisma/SiT),
|
| 2 |
-
# which is licensed under the MIT License.
|
| 3 |
-
#
|
| 4 |
-
# MIT License
|
| 5 |
-
#
|
| 6 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 7 |
-
#
|
| 8 |
-
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 9 |
-
# of this software and associated documentation files (the "Software"), to deal
|
| 10 |
-
# in the Software without restriction, including without limitation the rights
|
| 11 |
-
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 12 |
-
# copies of the Software, and to permit persons to whom the Software is
|
| 13 |
-
# furnished to do so, subject to the following conditions:
|
| 14 |
-
#
|
| 15 |
-
# The above copyright notice and this permission notice shall be included in all
|
| 16 |
-
# copies or substantial portions of the Software.
|
| 17 |
-
#
|
| 18 |
-
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 19 |
-
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 20 |
-
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 21 |
-
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 22 |
-
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 23 |
-
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 24 |
-
# SOFTWARE.
|
| 25 |
-
|
| 26 |
-
from .transport import Transport, ModelType, WeightType, PathType, Sampler
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
def create_transport(
    path_type='Linear',
    prediction="velocity",
    loss_weight=None,
    train_eps=None,
    sample_eps=None,
    train_sample_type="uniform",
    mean=0.0,
    std=1.0,
    shift_scale=1.0,
):
    """Build a ``Transport`` object from string/number configuration.

    Args:
        path_type: ``"Linear"``, ``"GVP"`` or ``"VP"``; any other value
            raises ``KeyError``.
        prediction: ``"noise"`` or ``"score"``; anything else selects
            velocity prediction.
        loss_weight: ``"velocity"`` or ``"likelihood"``; anything else
            disables loss weighting.
        train_eps: time-interval epsilon used for training. When ``None`` a
            path-dependent default is used.
        sample_eps: time-interval epsilon used for sampling. When ``None`` a
            path-dependent default is used.
        train_sample_type: forwarded to ``Transport`` unchanged.
        mean, std, shift_scale: forwarded to ``Transport`` unchanged.

    Returns:
        The configured ``Transport`` instance.

    Note:
        For velocity prediction on the Linear/GVP paths both epsilons are
        forced to 0, overriding any value passed in.
    """
    if prediction == "noise":
        model_type = ModelType.NOISE
    elif prediction == "score":
        model_type = ModelType.SCORE
    else:
        model_type = ModelType.VELOCITY

    if loss_weight == "velocity":
        loss_type = WeightType.VELOCITY
    elif loss_weight == "likelihood":
        loss_type = WeightType.LIKELIHOOD
    else:
        loss_type = WeightType.NONE

    path_choice = {
        "Linear": PathType.LINEAR,
        "GVP": PathType.GVP,
        "VP": PathType.VP,
    }

    path_type = path_choice[path_type]

    if path_type in [PathType.VP]:
        train_eps = 1e-5 if train_eps is None else train_eps
        # Test sample_eps itself: train_eps is never None at this point, so
        # testing it would leave an unset sample_eps as None.
        sample_eps = 1e-3 if sample_eps is None else sample_eps
    elif path_type in [PathType.GVP, PathType.LINEAR] and model_type != ModelType.VELOCITY:
        train_eps = 1e-3 if train_eps is None else train_eps
        sample_eps = 1e-3 if sample_eps is None else sample_eps
    else:  # velocity & [GVP, LINEAR] is stable everywhere
        train_eps = 0
        sample_eps = 0

    # create flow state
    state = Transport(
        model_type=model_type,
        path_type=path_type,
        loss_type=loss_type,
        train_eps=train_eps,
        sample_eps=sample_eps,
        train_sample_type=train_sample_type,
        mean=mean,
        std=std,
        shift_scale=shift_scale,
    )

    return state
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ultrashape/models/diffusion/transport/integrators.py
DELETED
|
@@ -1,142 +0,0 @@
|
|
| 1 |
-
# This file includes code derived from the SiT project (https://github.com/willisma/SiT),
|
| 2 |
-
# which is licensed under the MIT License.
|
| 3 |
-
#
|
| 4 |
-
# MIT License
|
| 5 |
-
#
|
| 6 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 7 |
-
#
|
| 8 |
-
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 9 |
-
# of this software and associated documentation files (the "Software"), to deal
|
| 10 |
-
# in the Software without restriction, including without limitation the rights
|
| 11 |
-
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 12 |
-
# copies of the Software, and to permit persons to whom the Software is
|
| 13 |
-
# furnished to do so, subject to the following conditions:
|
| 14 |
-
#
|
| 15 |
-
# The above copyright notice and this permission notice shall be included in all
|
| 16 |
-
# copies or substantial portions of the Software.
|
| 17 |
-
#
|
| 18 |
-
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 19 |
-
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 20 |
-
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 21 |
-
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 22 |
-
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 23 |
-
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 24 |
-
# SOFTWARE.
|
| 25 |
-
|
| 26 |
-
import numpy as np
|
| 27 |
-
import torch as th
|
| 28 |
-
import torch.nn as nn
|
| 29 |
-
from torchdiffeq import odeint
|
| 30 |
-
from functools import partial
|
| 31 |
-
from tqdm import tqdm
|
| 32 |
-
|
| 33 |
-
class sde:
    """Fixed-step SDE solver on a uniform time grid.

    Args:
        drift: callable ``drift(x, t, model, **model_kwargs)``.
        diffusion: callable ``diffusion(x, t)``; may return a tensor or a
            plain number.
        t0, t1: integration interval; ``t0 < t1`` is required.
        num_steps: number of grid points produced by ``linspace(t0, t1)``.
        sampler_type: ``"Euler"`` or ``"Heun"``.
    """

    def __init__(
        self,
        drift,
        diffusion,
        *,
        t0,
        t1,
        num_steps,
        sampler_type,
    ):
        assert t0 < t1, "SDE sampler has to be in forward time"

        self.num_timesteps = num_steps
        self.t = th.linspace(t0, t1, num_steps)
        self.dt = self.t[1] - self.t[0]  # uniform step size (0-dim tensor)
        self.drift = drift
        self.diffusion = diffusion
        self.sampler_type = sampler_type

    def __Euler_Maruyama_step(self, x, mean_x, t, model, **model_kwargs):
        """One Euler-Maruyama step; returns ``(x_next, mean_x_next)``.

        The incoming ``mean_x`` is ignored; the returned one is the
        noise-free update ``x + drift * dt``.
        """
        w_cur = th.randn(x.size()).to(x)
        t = th.ones(x.size(0)).to(x) * t
        dw = w_cur * th.sqrt(self.dt)
        drift = self.drift(x, t, model, **model_kwargs)
        diffusion = self.diffusion(x, t)
        mean_x = x + drift * self.dt
        # as_tensor lets a plain-number diffusion (e.g. a constant form)
        # through th.sqrt, which only accepts tensors.
        x = mean_x + th.sqrt(th.as_tensor(2 * diffusion)) * dw
        return x, mean_x

    def __Heun_step(self, x, _, t, model, **model_kwargs):
        """One Heun step; returns ``(x_next, xhat)``.

        Noise is added first (``xhat``), then the drift is averaged over
        the two evaluation points ``t`` and ``t + dt``.
        """
        w_cur = th.randn(x.size()).to(x)
        dw = w_cur * th.sqrt(self.dt)
        t_cur = th.ones(x.size(0)).to(x) * t
        diffusion = self.diffusion(x, t_cur)
        # See __Euler_Maruyama_step for why as_tensor is used.
        xhat = x + th.sqrt(th.as_tensor(2 * diffusion)) * dw
        K1 = self.drift(xhat, t_cur, model, **model_kwargs)
        xp = xhat + self.dt * K1
        K2 = self.drift(xp, t_cur + self.dt, model, **model_kwargs)
        return xhat + 0.5 * self.dt * (K1 + K2), xhat

    def __forward_fn(self):
        """Return the step function selected by ``self.sampler_type``.

        Raises:
            NotImplementedError: if the sampler type is unknown.
        """
        sampler_dict = {
            "Euler": self.__Euler_Maruyama_step,
            "Heun": self.__Heun_step,
        }

        try:
            return sampler_dict[self.sampler_type]
        except (KeyError, TypeError):
            # TypeError covers unhashable sampler_type values, which the
            # previous bare ``except`` also mapped to NotImplementedError.
            raise NotImplementedError("Sampler type not implemented.") from None

    def sample(self, init, model, **model_kwargs):
        """Run the forward SDE loop from ``init``.

        Returns:
            List with one state per step; one entry for every grid time
            except the last.
        """
        x = init
        mean_x = init
        samples = []
        sampler = self.__forward_fn()
        for ti in self.t[:-1]:
            with th.no_grad():
                x, mean_x = sampler(x, mean_x, ti, model, **model_kwargs)
                samples.append(x)

        return samples
|
| 101 |
-
|
| 102 |
-
class ode:
    """ODE solver wrapper around ``odeint``.

    Args:
        drift: callable ``drift(x, t, model, **model_kwargs)``.
        t0, t1: integration interval; ``t0 < t1`` is required.
        sampler_type: passed to ``odeint`` as ``method``.
        num_steps: number of time points produced by ``linspace(t0, t1)``.
        atol, rtol: solver tolerances, replicated once per state tensor.
    """

    def __init__(
        self,
        drift,
        *,
        t0,
        t1,
        sampler_type,
        num_steps,
        atol,
        rtol,
    ):
        assert t0 < t1, "ODE sampler has to be in forward time"

        self.drift = drift
        self.t = th.linspace(t0, t1, num_steps)
        self.atol = atol
        self.rtol = rtol
        self.sampler_type = sampler_type

    def sample(self, x, model, **model_kwargs):
        """Integrate the drift from ``x`` over ``self.t``.

        ``x`` may be a single tensor or a tuple of tensors; for a tuple the
        first element supplies the device and batch size.

        Returns:
            Whatever ``odeint`` returns for the requested time grid.
        """
        state_is_tuple = isinstance(x, tuple)
        device = x[0].device if state_is_tuple else x.device

        def _fn(t, x):
            # Broadcast the scalar solver time to one value per batch item.
            batch_size = x[0].size(0) if isinstance(x, tuple) else x.size(0)
            t = th.ones(batch_size).to(device) * t
            return self.drift(x, t, model, **model_kwargs)

        time_grid = self.t.to(device)
        # One tolerance entry per state tensor.
        n_states = len(x) if state_is_tuple else 1
        atol = [self.atol] * n_states
        rtol = [self.rtol] * n_states
        return odeint(
            _fn,
            x,
            time_grid,
            method=self.sampler_type,
            atol=atol,
            rtol=rtol,
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ultrashape/models/diffusion/transport/path.py
DELETED
|
@@ -1,220 +0,0 @@
|
|
| 1 |
-
# This file includes code derived from the SiT project (https://github.com/willisma/SiT),
|
| 2 |
-
# which is licensed under the MIT License.
|
| 3 |
-
#
|
| 4 |
-
# MIT License
|
| 5 |
-
#
|
| 6 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 7 |
-
#
|
| 8 |
-
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 9 |
-
# of this software and associated documentation files (the "Software"), to deal
|
| 10 |
-
# in the Software without restriction, including without limitation the rights
|
| 11 |
-
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 12 |
-
# copies of the Software, and to permit persons to whom the Software is
|
| 13 |
-
# furnished to do so, subject to the following conditions:
|
| 14 |
-
#
|
| 15 |
-
# The above copyright notice and this permission notice shall be included in all
|
| 16 |
-
# copies or substantial portions of the Software.
|
| 17 |
-
#
|
| 18 |
-
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 19 |
-
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 20 |
-
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 21 |
-
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 22 |
-
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 23 |
-
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 24 |
-
# SOFTWARE.
|
| 25 |
-
|
| 26 |
-
import torch as th
|
| 27 |
-
import numpy as np
|
| 28 |
-
from functools import partial
|
| 29 |
-
|
| 30 |
-
def expand_t_like_x(t, x):
    """Reshape the time vector ``t`` so it broadcasts against ``x``.

    Args:
        t: [batch_dim,], time vector
        x: [batch_dim,...], data point

    Returns:
        A view of ``t`` with shape ``[batch_dim, 1, ..., 1]`` having the same
        number of dimensions as ``x``.
    """
    trailing_ones = (1,) * (x.dim() - 1)
    return t.view(t.size(0), *trailing_ones)
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
#################### Coupling Plans ####################
|
| 42 |
-
|
| 43 |
-
class ICPlan:
    """Linear coupling plan: ``x_t = t * x1 + (1 - t) * x0``.

    ``x0`` is the endpoint at ``t = 0`` and ``x1`` the endpoint at ``t = 1``.
    Subclasses override the ``compute_*`` coefficient methods to define
    other paths.
    """

    def __init__(self, sigma=0.0):
        # Stored only; none of the methods in this class read it.
        self.sigma = sigma

    def compute_alpha_t(self, t):
        """Return ``(alpha_t, d_alpha_t)``: the x1 coefficient and its time derivative."""
        return t, 1

    def compute_sigma_t(self, t):
        """Return ``(sigma_t, d_sigma_t)``: the x0 coefficient and its time derivative."""
        return 1 - t, -1

    def compute_d_alpha_alpha_ratio_t(self, t):
        """Return ``d_alpha_t / alpha_t`` (``1 / t`` here; infinite at ``t == 0``)."""
        return 1 / t

    def compute_drift(self, x, t):
        """Return ``(drift, diffusion)`` of the SDE in score parametrization.

        Args:
            x: [batch_dim, ...], data point
            t: [batch_dim,], time vector
        """
        t = expand_t_like_x(t, x)
        alpha_ratio = self.compute_d_alpha_alpha_ratio_t(t)
        sigma_t, d_sigma_t = self.compute_sigma_t(t)
        drift = alpha_ratio * x
        diffusion = alpha_ratio * (sigma_t ** 2) - sigma_t * d_sigma_t

        return -drift, diffusion

    def compute_diffusion(self, x, t, form="constant", norm=1.0):
        """Compute the diffusion term of the SDE.

        Args:
            x: [batch_dim, ...], data point
            t: [batch_dim,], time vector
            form: str, form of the diffusion term
            norm: float, norm of the diffusion term

        Returns:
            ``norm`` itself for ``"constant"``; otherwise a tensor
            broadcastable against ``x``.

        Raises:
            NotImplementedError: if ``form`` is not a known form.
        """
        t = expand_t_like_x(t, x)
        # Each form is wrapped in a callable so that only the requested one
        # is evaluated.
        choices = {
            "constant": lambda: norm,
            "SBDM": lambda: norm * self.compute_drift(x, t)[1],
            "sigma": lambda: norm * self.compute_sigma_t(t)[0],
            "linear": lambda: norm * (1 - t),
            "decreasing": lambda: 0.25 * (norm * th.cos(np.pi * t) + 1) ** 2,
            "inccreasing-decreasing": lambda: norm * th.sin(np.pi * t) ** 2,
        }
        # Correctly spelled alias; the misspelled key is kept for existing callers.
        choices["increasing-decreasing"] = choices["inccreasing-decreasing"]

        try:
            build = choices[form]
        except KeyError:
            raise NotImplementedError(f"Diffusion form {form} not implemented") from None

        return build()

    def get_score_from_velocity(self, velocity, x, t):
        """Convert a velocity prediction into a score.

        Args:
            velocity: [batch_dim, ...] shaped tensor; velocity model output
            x: [batch_dim, ...] shaped tensor; x_t data point
            t: [batch_dim,] time tensor
        """
        t = expand_t_like_x(t, x)
        alpha_t, d_alpha_t = self.compute_alpha_t(t)
        sigma_t, d_sigma_t = self.compute_sigma_t(t)
        mean = x
        reverse_alpha_ratio = alpha_t / d_alpha_t
        var = sigma_t**2 - reverse_alpha_ratio * d_sigma_t * sigma_t
        score = (reverse_alpha_ratio * velocity - mean) / var
        return score

    def get_noise_from_velocity(self, velocity, x, t):
        """Convert a velocity prediction into a noise prediction.

        Args:
            velocity: [batch_dim, ...] shaped tensor; velocity model output
            x: [batch_dim, ...] shaped tensor; x_t data point
            t: [batch_dim,] time tensor
        """
        t = expand_t_like_x(t, x)
        alpha_t, d_alpha_t = self.compute_alpha_t(t)
        sigma_t, d_sigma_t = self.compute_sigma_t(t)
        mean = x
        reverse_alpha_ratio = alpha_t / d_alpha_t
        var = reverse_alpha_ratio * d_sigma_t - sigma_t
        noise = (reverse_alpha_ratio * velocity - mean) / var
        return noise

    def get_velocity_from_score(self, score, x, t):
        """Convert a score prediction into a velocity.

        Args:
            score: [batch_dim, ...] shaped tensor; score model output
            x: [batch_dim, ...] shaped tensor; x_t data point
            t: [batch_dim,] time tensor
        """
        t = expand_t_like_x(t, x)
        drift, var = self.compute_drift(x, t)
        velocity = var * score - drift
        return velocity

    def compute_mu_t(self, t, x0, x1):
        """Compute the mean of the time-dependent density p_t."""
        t = expand_t_like_x(t, x1)
        alpha_t, _ = self.compute_alpha_t(t)
        sigma_t, _ = self.compute_sigma_t(t)
        # t*x1 + (1-t)*x0 ; t=0 x0; t=1 x1
        return alpha_t * x1 + sigma_t * x0

    def compute_xt(self, t, x0, x1):
        """Return x_t; deterministic here, equal to ``compute_mu_t(t, x0, x1)``."""
        xt = self.compute_mu_t(t, x0, x1)
        return xt

    def compute_ut(self, t, x0, x1, xt):
        """Compute the vector field corresponding to p_t (``xt`` is not used)."""
        t = expand_t_like_x(t, x1)
        _, d_alpha_t = self.compute_alpha_t(t)
        _, d_sigma_t = self.compute_sigma_t(t)
        return d_alpha_t * x1 + d_sigma_t * x0

    def plan(self, t, x0, x1):
        """Return ``(t, xt, ut)`` for the given time and endpoints."""
        xt = self.compute_xt(t, x0, x1)
        ut = self.compute_ut(t, x0, x1, xt)
        return t, xt, ut
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
class VPCPlan(ICPlan):
    """Coupling plan for the VP path.

    The x1 coefficient is ``exp(log_mean_coeff(t))``, with a noise rate that
    runs between ``sigma_min`` and ``sigma_max``.
    """

    def __init__(self, sigma_min=0.1, sigma_max=20.0):
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        # Log of the x1 coefficient. Both callables read sigma_min/sigma_max
        # at call time.
        self.log_mean_coeff = lambda t: (
            -0.25 * ((1 - t) ** 2) * (self.sigma_max - self.sigma_min)
            - 0.5 * (1 - t) * self.sigma_min
        )
        # Time derivative of ``log_mean_coeff``.
        self.d_log_mean_coeff = lambda t: (
            0.5 * (1 - t) * (self.sigma_max - self.sigma_min)
            + 0.5 * self.sigma_min
        )

    def compute_alpha_t(self, t):
        """Return the x1 coefficient and its time derivative."""
        log_alpha = self.log_mean_coeff(t)
        alpha_t = th.exp(log_alpha)
        return alpha_t, alpha_t * self.d_log_mean_coeff(t)

    def compute_sigma_t(self, t):
        """Return the x0 coefficient and its time derivative."""
        alpha_squared = th.exp(2 * self.log_mean_coeff(t))
        sigma_t = th.sqrt(1 - alpha_squared)
        d_sigma_t = alpha_squared * (2 * self.d_log_mean_coeff(t)) / (-2 * sigma_t)
        return sigma_t, d_sigma_t

    def compute_d_alpha_alpha_ratio_t(self, t):
        """Return d_alpha_t / alpha_t directly from the log-coefficient derivative."""
        return self.d_log_mean_coeff(t)

    def compute_drift(self, x, t):
        """Return the ``(drift, diffusion)`` pair of the SDE."""
        t = expand_t_like_x(t, x)
        beta_t = self.sigma_min + (1 - t) * (self.sigma_max - self.sigma_min)
        drift = -0.5 * beta_t * x
        return drift, beta_t / 2
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
class GVPCPlan(ICPlan):
    """Coupling plan with trigonometric coefficients.

    The x1 coefficient is ``sin(pi * t / 2)`` and the x0 coefficient is
    ``cos(pi * t / 2)``.
    """

    def __init__(self, sigma=0.0):
        super().__init__(sigma)

    def compute_alpha_t(self, t):
        """Return the x1 coefficient and its time derivative."""
        angle = t * np.pi / 2
        return th.sin(angle), np.pi / 2 * th.cos(angle)

    def compute_sigma_t(self, t):
        """Return the x0 coefficient and its time derivative."""
        angle = t * np.pi / 2
        return th.cos(angle), -np.pi / 2 * th.sin(angle)

    def compute_d_alpha_alpha_ratio_t(self, t):
        """Return d_alpha_t / alpha_t in the closed form pi / (2 * tan(pi * t / 2))."""
        tangent = th.tan(t * np.pi / 2)
        return np.pi / (2 * tangent)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ultrashape/models/diffusion/transport/transport.py
DELETED
|
@@ -1,534 +0,0 @@
|
|
| 1 |
-
# This file includes code derived from the SiT project (https://github.com/willisma/SiT),
|
| 2 |
-
# which is licensed under the MIT License.
|
| 3 |
-
#
|
| 4 |
-
# MIT License
|
| 5 |
-
#
|
| 6 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 7 |
-
#
|
| 8 |
-
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 9 |
-
# of this software and associated documentation files (the "Software"), to deal
|
| 10 |
-
# in the Software without restriction, including without limitation the rights
|
| 11 |
-
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 12 |
-
# copies of the Software, and to permit persons to whom the Software is
|
| 13 |
-
# furnished to do so, subject to the following conditions:
|
| 14 |
-
#
|
| 15 |
-
# The above copyright notice and this permission notice shall be included in all
|
| 16 |
-
# copies or substantial portions of the Software.
|
| 17 |
-
#
|
| 18 |
-
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 19 |
-
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 20 |
-
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 21 |
-
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 22 |
-
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 23 |
-
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 24 |
-
# SOFTWARE.
|
| 25 |
-
|
| 26 |
-
import torch as th
|
| 27 |
-
import numpy as np
|
| 28 |
-
import logging
|
| 29 |
-
|
| 30 |
-
import enum
|
| 31 |
-
|
| 32 |
-
from . import path
|
| 33 |
-
from .utils import EasyDict, log_state, mean_flat
|
| 34 |
-
from .integrators import ode, sde
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
class ModelType(enum.Enum):
    """Quantity the model is trained to predict."""

    # Values are written out; they match what enum.auto() assigns (1, 2, 3).
    # The model predicts epsilon (the noise).
    NOISE = 1
    # The model predicts the score, i.e. the gradient of log p(x).
    SCORE = 2
    # The model predicts the velocity v(x).
    VELOCITY = 3
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
class PathType(enum.Enum):
    """Interpolation path selected for the transport."""

    # Values are written out; they match what enum.auto() assigns (1, 2, 3).
    LINEAR = 1
    GVP = 2
    VP = 3
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
class WeightType(enum.Enum):
    """Loss weighting scheme."""

    # Values are written out; they match what enum.auto() assigns (1, 2, 3).
    NONE = 1
    VELOCITY = 2
    LIKELIHOOD = 3
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
class Transport:
|
| 68 |
-
|
| 69 |
-
def __init__(
    self,
    *,
    model_type,
    path_type,
    loss_type,
    train_eps,
    sample_eps,
    train_sample_type="uniform",
    **kwargs,
):
    """Store the transport configuration and build the path sampler.

    Args:
        model_type: ``ModelType`` member describing the model output.
        path_type: ``PathType`` member; selects the coupling-plan class.
        loss_type: ``WeightType`` member used for loss weighting.
        train_eps, sample_eps: time-interval epsilons for training/sampling.
        train_sample_type: time-sampling scheme. For ``"logit_normal"`` the
            keyword arguments ``mean``, ``std`` and ``shift_scale`` are
            required (``KeyError`` otherwise).
        **kwargs: extra settings; only read for ``"logit_normal"``.
    """
    plan_classes = {
        PathType.LINEAR: path.ICPlan,
        PathType.GVP: path.GVPCPlan,
        PathType.VP: path.VPCPlan,
    }

    self.loss_type = loss_type
    self.model_type = model_type
    plan_cls = plan_classes[path_type]
    self.path_sampler = plan_cls()
    self.train_eps = train_eps
    self.sample_eps = sample_eps
    self.train_sample_type = train_sample_type

    if train_sample_type == "logit_normal":
        self.mean = kwargs['mean']
        self.std = kwargs['std']
        self.shift_scale = kwargs['shift_scale']
        print(f"using logit normal sample, shift scale is {self.shift_scale}")
|
| 97 |
-
|
| 98 |
-
def prior_logp(self, z):
    """Log-density of ``z`` under a standard multivariate normal prior.

    ``z`` is assumed to be batched along its first dimension; one value is
    returned per batch element.
    """
    # Number of elements per sample (product of the non-batch dimensions).
    per_sample_dims = th.prod(th.tensor(z.size())[1:])

    def _log_density(sample):
        return -per_sample_dims / 2. * np.log(2 * np.pi) - th.sum(sample ** 2) / 2.

    return th.vmap(_log_density)(z)
|
| 107 |
-
|
| 108 |
-
def check_interval(
    self,
    train_eps,
    sample_eps,
    *,
    diffusion_form="SBDM",
    sde=False,
    reverse=False,
    eval=False,
    last_step_size=0.0,
):
    """Return the ``(t0, t1)`` time interval for training or sampling.

    Args:
        train_eps: epsilon used when ``eval`` is false.
        sample_eps: epsilon used when ``eval`` is true.
        diffusion_form: diffusion form name; ``"SBDM"`` together with
            ``sde`` moves ``t0`` to the epsilon.
        sde: whether the interval is for an SDE sampler.
        reverse: if true, both bounds are mirrored as ``1 - t``.
        eval: selects ``sample_eps`` over ``train_eps``.
        last_step_size: when non-zero and ``sde`` is set, the upper bound is
            ``1 - last_step_size`` instead of ``1 - eps``.

    Returns:
        Tuple ``(t0, t1)``; ``(0, 1)`` when no adjustment applies.
    """
    eps = sample_eps if eval else train_eps
    # Exact type match on purpose: the other plan classes subclass ICPlan.
    plan_type = type(self.path_sampler)
    not_velocity = self.model_type != ModelType.VELOCITY

    def _upper_bound():
        # Computed only in the branches that need it, so ``eps`` is not
        # touched otherwise.
        if not sde or last_step_size == 0:
            return 1 - eps
        return 1 - last_step_size

    t0, t1 = 0, 1
    if plan_type in [path.VPCPlan]:
        t1 = _upper_bound()
    elif plan_type in [path.ICPlan, path.GVPCPlan] and (not_velocity or sde):
        # avoid numerical issue by taking a first semi-implicit step
        if (diffusion_form == "SBDM" and sde) or not_velocity:
            t0 = eps
        t1 = _upper_bound()

    if reverse:
        t0, t1 = 1 - t0, 1 - t1

    return t0, t1
|
| 137 |
-
|
| 138 |
-
def sample(self, x1):
    """Sample the noise ``x0`` and the time ``t`` for a batch ``x1``.

    Args:
        x1: data point; [batch, *dim]

    Returns:
        Tuple ``(t, x0, x1)`` where ``t`` has shape ``[batch]`` and ``x0``
        is Gaussian noise shaped like ``x1``.

    Raises:
        NotImplementedError: if ``self.train_sample_type`` is neither
            ``"uniform"`` nor ``"logit_normal"``.
    """
    x0 = th.randn_like(x1)
    if self.train_sample_type == "uniform":
        t0, t1 = self.check_interval(self.train_eps, self.sample_eps)
        t = th.rand((x1.shape[0],)) * (t1 - t0) + t0
        t = t.to(x1)
    elif self.train_sample_type == "logit_normal":
        t = th.randn((x1.shape[0],)) * self.std + self.mean
        t = t.to(x1)
        t = 1 / (1 + th.exp(-t))  # logistic squashing into (0, 1)

        # Re-map t with the square root of the shift scale.
        root_shift = np.sqrt(self.shift_scale)
        t = root_shift * t / (1 + (root_shift - 1) * t)
    else:
        # Previously this fell through to an UnboundLocalError on ``t``.
        raise NotImplementedError(
            f"Train sample type {self.train_sample_type} not implemented"
        )

    return t, x0, x1
|
| 157 |
-
|
| 158 |
-
def training_losses(
    self,
    model,
    x1,
    model_kwargs=None
):
    """Loss for training the score model

    Args:
    - model: backbone model; could be score, noise, or velocity
    - x1: datapoint
    - model_kwargs: additional arguments for the model

    Returns:
        dict with ``'pred'`` (the raw model output) and ``'loss'`` (the
        squared error reduced by ``mean_flat``).

    Raises:
        NotImplementedError: if ``self.loss_type`` is not a handled weight type
            (only checked for non-velocity models).
    """
    if model_kwargs is None:
        model_kwargs = {}

    t, x0, x1 = self.sample(x1)
    t, xt, ut = self.path_sampler.plan(t, x0, x1)
    model_output = model(xt, t, **model_kwargs)
    B, *_, C = xt.shape
    assert model_output.size() == (B, *xt.size()[1:-1], C), \
        "Model output shape must match the shape of the noised input"

    terms = {}
    terms['pred'] = model_output
    if self.model_type == ModelType.VELOCITY:
        # Velocity models regress directly onto the target velocity `ut`.
        terms['loss'] = mean_flat(((model_output - ut) ** 2))
    else:
        _, drift_var = self.path_sampler.compute_drift(xt, t)
        sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, xt))
        if self.loss_type in [WeightType.VELOCITY]:
            weight = (drift_var / sigma_t) ** 2
        elif self.loss_type in [WeightType.LIKELIHOOD]:
            weight = drift_var / (sigma_t ** 2)
        elif self.loss_type in [WeightType.NONE]:
            weight = 1
        else:
            raise NotImplementedError()

        if self.model_type == ModelType.NOISE:
            # Noise models are compared against the sampled noise x0.
            terms['loss'] = mean_flat(weight * ((model_output - x0) ** 2))
        else:
            # Remaining (score) case: output * sigma_t is compared against -x0.
            terms['loss'] = mean_flat(weight * ((model_output * sigma_t + x0) ** 2))

    return terms
|
| 201 |
-
|
| 202 |
-
def get_drift(
    self
):
    """Build the drift function of the probability flow ODE.

    Returns:
        A callable ``body_fn(x, t, model, **model_kwargs)`` that evaluates the
        drift for the configured ``self.model_type`` and asserts that the
        result has the same shape as ``x``.
    """
    sampler = self.path_sampler

    def velocity_ode(x, t, model, **model_kwargs):
        # A velocity model already outputs the drift.
        return model(x, t, **model_kwargs)

    def score_ode(x, t, model, **model_kwargs):
        drift_mean, drift_var = sampler.compute_drift(x, t)
        prediction = model(x, t, **model_kwargs)
        return (-drift_mean + drift_var * prediction)  # by change of variable

    def noise_ode(x, t, model, **model_kwargs):
        drift_mean, drift_var = sampler.compute_drift(x, t)
        sigma_t, _ = sampler.compute_sigma_t(path.expand_t_like_x(t, x))
        prediction = model(x, t, **model_kwargs)
        # Convert the noise prediction into a score before reusing the score formula.
        score = prediction / -sigma_t
        return (-drift_mean + drift_var * score)

    drift_fn = velocity_ode
    if self.model_type == ModelType.NOISE:
        drift_fn = noise_ode
    elif self.model_type == ModelType.SCORE:
        drift_fn = score_ode

    def body_fn(x, t, model, **model_kwargs):
        result = drift_fn(x, t, model, **model_kwargs)
        assert result.shape == x.shape, "Output shape from ODE solver must match input shape"
        return result

    return body_fn
|
| 236 |
-
|
| 237 |
-
def get_score(
    self,
):
    """Build the score function for x_t = alpha_t * x + sigma_t * eps.

    Returns:
        A callable ``score_fn(x, t, model, **kwargs)`` that converts the model
        output into a score according to ``self.model_type``.

    Raises:
        NotImplementedError: if ``self.model_type`` is not NOISE, SCORE or VELOCITY.
    """
    sampler = self.path_sampler

    if self.model_type == ModelType.NOISE:
        def score_fn(x, t, model, **kwargs):
            # score = -noise / sigma_t
            sigma_t = sampler.compute_sigma_t(path.expand_t_like_x(t, x))[0]
            return model(x, t, **kwargs) / -sigma_t
    elif self.model_type == ModelType.SCORE:
        def score_fn(x, t, model, **kwargs):
            # The model output is already the score.
            return model(x, t, **kwargs)
    elif self.model_type == ModelType.VELOCITY:
        def score_fn(x, t, model, **kwargs):
            return sampler.get_score_from_velocity(model(x, t, **kwargs), x, t)
    else:
        raise NotImplementedError()

    return score_fn
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
class Sampler:
    """Sampler class for the transport model"""

    def __init__(
        self,
        transport,
    ):
        """Constructor for a general sampler; supporting different sampling methods
        Args:
        - transport: a transport object specifying model prediction & interpolant type
        """

        self.transport = transport
        self.drift = self.transport.get_drift()
        self.score = self.transport.get_score()

    def __get_sde_diffusion_and_drift(
        self,
        *,
        diffusion_form="SBDM",
        diffusion_norm=1.0,
    ):
        """Return ``(sde_drift, sde_diffusion)`` callables for the SDE solver.

        ``sde_diffusion(x, t)`` forwards to the path sampler's
        ``compute_diffusion``; ``sde_drift(x, t, model, **kwargs)`` is the ODE
        drift plus ``diffusion * score``.
        """

        def diffusion_fn(x, t):
            return self.transport.path_sampler.compute_diffusion(
                x, t, form=diffusion_form, norm=diffusion_norm
            )

        def sde_drift(x, t, model, **kwargs):
            return self.drift(x, t, model, **kwargs) \
                + diffusion_fn(x, t) * self.score(x, t, model, **kwargs)

        return sde_drift, diffusion_fn

    def __get_last_step(
        self,
        sde_drift,
        *,
        last_step,
        last_step_size,
    ):
        """Get the last step function of the SDE solver

        Args:
        - sde_drift: drift callable used by the "Mean" step
        - last_step: one of None (identity), "Mean", "Tweedie" or "Euler"
        - last_step_size: step size applied by "Mean" and "Euler"

        Raises:
            NotImplementedError: for any other ``last_step`` value.
        """

        if last_step is None:
            def last_step_fn(x, t, model, **model_kwargs):
                return x
        elif last_step == "Mean":
            def last_step_fn(x, t, model, **model_kwargs):
                return x + sde_drift(x, t, model, **model_kwargs) * last_step_size
        elif last_step == "Tweedie":
            alpha = self.transport.path_sampler.compute_alpha_t  # simple aliasing; the original name was too long
            sigma = self.transport.path_sampler.compute_sigma_t

            def last_step_fn(x, t, model, **model_kwargs):
                alpha_t = alpha(t)[0][0]
                sigma_t = sigma(t)[0][0]
                return x / alpha_t + (sigma_t ** 2) / alpha_t * self.score(x, t, model, **model_kwargs)
        elif last_step == "Euler":
            def last_step_fn(x, t, model, **model_kwargs):
                return x + self.drift(x, t, model, **model_kwargs) * last_step_size
        else:
            raise NotImplementedError()

        return last_step_fn

    def _ode_drift_and_interval(self, reverse):
        """Return ``(drift, t0, t1)`` shared by the plain ODE samplers."""
        if reverse:
            # Evaluate the drift at the mirrored time 1 - t.
            drift = lambda x, t, model, **kwargs: self.drift(x, th.ones_like(t) * (1 - t), model, **kwargs)
        else:
            drift = self.drift

        t0, t1 = self.transport.check_interval(
            self.transport.train_eps,
            self.transport.sample_eps,
            sde=False,
            eval=True,
            reverse=reverse,
            last_step_size=0.0,
        )
        return drift, t0, t1

    def sample_sde(
        self,
        *,
        sampling_method="Euler",
        diffusion_form="SBDM",
        diffusion_norm=1.0,
        last_step="Mean",
        last_step_size=0.04,
        num_steps=250,
    ):
        """returns a sampling function with given SDE settings
        Args:
        - sampling_method: type of sampler used in solving the SDE; default to be Euler-Maruyama
        - diffusion_form: function form of diffusion coefficient; default to be matching SBDM
        - diffusion_norm: function magnitude of diffusion coefficient; default to 1
        - last_step: type of the last step; default to "Mean" (None means identity)
        - last_step_size: size of the last step; default to 0.04 (forced to 0 when last_step is None)
        - num_steps: total integration step of SDE

        Returns:
            ``_sample(init, model, **model_kwargs)``, which returns the list of
            solver states with the last-step result appended.
        """

        if last_step is None:
            last_step_size = 0.0

        sde_drift, sde_diffusion = self.__get_sde_diffusion_and_drift(
            diffusion_form=diffusion_form,
            diffusion_norm=diffusion_norm,
        )

        t0, t1 = self.transport.check_interval(
            self.transport.train_eps,
            self.transport.sample_eps,
            diffusion_form=diffusion_form,
            sde=True,
            eval=True,
            reverse=False,
            last_step_size=last_step_size,
        )

        _sde = sde(
            sde_drift,
            sde_diffusion,
            t0=t0,
            t1=t1,
            num_steps=num_steps,
            sampler_type=sampling_method
        )

        last_step_fn = self.__get_last_step(sde_drift, last_step=last_step, last_step_size=last_step_size)

        def _sample(init, model, **model_kwargs):
            xs = _sde.sample(init, model, **model_kwargs)
            # The final step is taken from the end of the solver interval.
            ts = th.ones(init.size(0), device=init.device) * t1
            x = last_step_fn(xs[-1], ts, model, **model_kwargs)
            xs.append(x)

            assert len(xs) == num_steps, "Samples does not match the number of steps"

            return xs

        return _sample

    def sample_ode(
        self,
        *,
        sampling_method="dopri5",
        num_steps=50,
        atol=1e-6,
        rtol=1e-3,
        reverse=False,
    ):
        """returns a sampling function with given ODE settings
        Args:
        - sampling_method: type of sampler used in solving the ODE; default to be Dopri5
        - num_steps:
            - fixed solver (Euler, Heun): the actual number of integration steps performed
            - adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation
        - atol: absolute error tolerance for the solver
        - rtol: relative error tolerance for the solver
        - reverse: whether solving the ODE in reverse (data to noise); default to False
        """
        drift, t0, t1 = self._ode_drift_and_interval(reverse)

        _ode = ode(
            drift=drift,
            t0=t0,
            t1=t1,
            sampler_type=sampling_method,
            num_steps=num_steps,
            atol=atol,
            rtol=rtol,
        )

        return _ode.sample

    def sample_ode_intermediate(
        self,
        *,
        sampling_method="dopri5",
        num_steps=50,
        atol=1e-6,
        rtol=1e-3,
        t=0.5,
        reverse=False,
    ):
        """returns a sampling function with given ODE settings, starting at time ``t``
        Args:
        - sampling_method: type of sampler used in solving the ODE; default to be Dopri5
        - num_steps:
            - fixed solver (Euler, Heun): the actual number of integration steps performed
            - adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation
        - atol: absolute error tolerance for the solver
        - rtol: relative error tolerance for the solver
        - t: start time handed to the solver in place of the interval's t0; default to 0.5
        - reverse: whether solving the ODE in reverse (data to noise); default to False
        """
        # Only the end of the checked interval is used; the start is `t`.
        drift, _, t1 = self._ode_drift_and_interval(reverse)

        _ode = ode(
            drift=drift,
            t0=t,
            t1=t1,
            sampler_type=sampling_method,
            num_steps=num_steps,
            atol=atol,
            rtol=rtol,
        )

        return _ode.sample

    def sample_ode_likelihood(
        self,
        *,
        sampling_method="dopri5",
        num_steps=50,
        atol=1e-6,
        rtol=1e-3,
    ):

        """returns a sampling function for calculating likelihood with given ODE settings
        Args:
        - sampling_method: type of sampler used in solving the ODE; default to be Dopri5
        - num_steps:
            - fixed solver (Euler, Heun): the actual number of integration steps performed
            - adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation
        - atol: absolute error tolerance for the solver
        - rtol: relative error tolerance for the solver

        Returns:
            ``_sample_fn(x, model, **model_kwargs)`` returning ``(logp, final_state)``.
        """

        def _likelihood_drift(x, t, model, **model_kwargs):
            x, _ = x
            # Random +/-1 probe vector; eps^T (d drift / dx) eps is used as the
            # divergence estimate (looks like a Hutchinson-style estimator).
            eps = th.randint(2, x.size(), dtype=th.float, device=x.device) * 2 - 1
            t = th.ones_like(t) * (1 - t)
            with th.enable_grad():
                x.requires_grad = True
                grad = th.autograd.grad(th.sum(self.drift(x, t, model, **model_kwargs) * eps), x)[0]
                logp_grad = th.sum(grad * eps, dim=tuple(range(1, len(x.size()))))
                drift = self.drift(x, t, model, **model_kwargs)
            return (-drift, logp_grad)

        t0, t1 = self.transport.check_interval(
            self.transport.train_eps,
            self.transport.sample_eps,
            sde=False,
            eval=True,
            reverse=False,
            last_step_size=0.0,
        )

        _ode = ode(
            drift=_likelihood_drift,
            t0=t0,
            t1=t1,
            sampler_type=sampling_method,
            num_steps=num_steps,
            atol=atol,
            rtol=rtol,
        )

        def _sample_fn(x, model, **model_kwargs):
            init_logp = th.zeros(x.size(0)).to(x)
            state = (x, init_logp)
            drift, delta_logp = _ode.sample(state, model, **model_kwargs)
            # Keep only the final solver state and its accumulated log-density change.
            drift, delta_logp = drift[-1], delta_logp[-1]
            prior_logp = self.transport.prior_logp(drift)
            logp = prior_logp - delta_logp
            return logp, drift

        return _sample_fn
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ultrashape/models/diffusion/transport/utils.py
DELETED
|
@@ -1,54 +0,0 @@
|
|
| 1 |
-
# This file includes code derived from the SiT project (https://github.com/willisma/SiT),
|
| 2 |
-
# which is licensed under the MIT License.
|
| 3 |
-
#
|
| 4 |
-
# MIT License
|
| 5 |
-
#
|
| 6 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 7 |
-
#
|
| 8 |
-
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 9 |
-
# of this software and associated documentation files (the "Software"), to deal
|
| 10 |
-
# in the Software without restriction, including without limitation the rights
|
| 11 |
-
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 12 |
-
# copies of the Software, and to permit persons to whom the Software is
|
| 13 |
-
# furnished to do so, subject to the following conditions:
|
| 14 |
-
#
|
| 15 |
-
# The above copyright notice and this permission notice shall be included in all
|
| 16 |
-
# copies or substantial portions of the Software.
|
| 17 |
-
#
|
| 18 |
-
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 19 |
-
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 20 |
-
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 21 |
-
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 22 |
-
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 23 |
-
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 24 |
-
# SOFTWARE.
|
| 25 |
-
|
| 26 |
-
import torch as th
|
| 27 |
-
|
| 28 |
-
class EasyDict:
    """Expose the entries of a mapping as attributes, readable via ``obj[key]`` too."""

    def __init__(self, sub_dict):
        # Copy every entry of the mapping onto the instance as an attribute.
        for name, value in sub_dict.items():
            setattr(self, name, value)

    def __getitem__(self, key):
        # Item access is plain attribute lookup, so a missing key raises
        # AttributeError rather than KeyError.
        return getattr(self, key)
|
| 36 |
-
|
| 37 |
-
def mean_flat(x):
    """
    Average a tensor over every dimension except the first (batch) one.
    """
    non_batch_dims = list(range(1, x.dim()))
    return x.mean(dim=non_batch_dims)
|
| 42 |
-
|
| 43 |
-
def log_state(state):
    """Format a mapping as sorted ``key: value`` lines joined by newlines.

    Values whose string form looks like a default object repr are shown as
    ``[ClassName]`` instead.
    """

    def _render(key, value):
        text = str(value)
        # Check if the value is an instance of a class
        if "<object" in text or "object at" in text:
            return f"{key}: [{value.__class__.__name__}]"
        return f"{key}: {value}"

    return '\n'.join(_render(key, value) for key, value in sorted(state.items()))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ultrashape/pipelines.py
DELETED
|
@@ -1,797 +0,0 @@
|
|
| 1 |
-
# ==============================================================================
|
| 2 |
-
# Original work Copyright (c) 2025 Tencent.
|
| 3 |
-
# Modified work Copyright (c) 2025 UltraShape Team.
|
| 4 |
-
#
|
| 5 |
-
# Modified by UltraShape on 2025.12.25
|
| 6 |
-
# ==============================================================================
|
| 7 |
-
|
| 8 |
-
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
| 9 |
-
# except for the third-party components listed below.
|
| 10 |
-
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
| 11 |
-
in the respective licenses of these third-party components.
|
| 12 |
-
# Users must comply with all terms and conditions of original licenses of these third-party
|
| 13 |
-
# components and must ensure that the usage of the third party components adheres to
|
| 14 |
-
# all relevant laws and regulations.
|
| 15 |
-
|
| 16 |
-
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
| 17 |
-
# their software and algorithms, including trained model weights, parameters (including
|
| 18 |
-
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
| 19 |
-
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
| 20 |
-
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
| 21 |
-
|
| 22 |
-
import copy
|
| 23 |
-
import importlib
|
| 24 |
-
import inspect
|
| 25 |
-
import os
|
| 26 |
-
from typing import List, Optional, Union
|
| 27 |
-
|
| 28 |
-
import numpy as np
|
| 29 |
-
import torch
|
| 30 |
-
import trimesh
|
| 31 |
-
import yaml
|
| 32 |
-
from PIL import Image
|
| 33 |
-
from diffusers.utils.torch_utils import randn_tensor
|
| 34 |
-
from diffusers.utils.import_utils import is_accelerate_version, is_accelerate_available
|
| 35 |
-
from tqdm import tqdm
|
| 36 |
-
|
| 37 |
-
from .models.autoencoders import ShapeVAE
|
| 38 |
-
from .models.autoencoders import SurfaceExtractors
|
| 39 |
-
from .utils import logger, synchronize_timer, smart_load_model
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Configure the scheduler through `set_timesteps` and return its resulting schedule. Supports a custom
    `timesteps` or `sigmas` override. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.

    Raises:
        ValueError: if both `timesteps` and `sigmas` are given, or if the scheduler's `set_timesteps` does not accept
        the requested override.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    if timesteps is not None:
        param_name, override, label = "timesteps", timesteps, "timestep"
    elif sigmas is not None:
        param_name, override, label = "sigmas", sigmas, "sigmas"
    else:
        # No override: let the scheduler derive the schedule from the step count.
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # The override is only usable if `set_timesteps` declares a matching parameter.
    accepted = inspect.signature(scheduler.set_timesteps).parameters
    if param_name not in accepted:
        raise ValueError(
            f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
            f" {label} schedules. Please check whether you are using the correct scheduler."
        )
    scheduler.set_timesteps(device=device, **{param_name: override}, **kwargs)
    schedule = scheduler.timesteps
    return schedule, len(schedule)
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
@synchronize_timer('Export to trimesh')
def export_to_trimesh(mesh_output):
    """Convert one mesh result, or a list of them, into ``trimesh.Trimesh`` objects.

    ``None`` entries in a list are passed through as ``None``. Each input
    mesh's ``mesh_f`` is replaced in place by its column-reversed version
    before the Trimesh is built.
    """

    def _convert(mesh):
        # Reverse the column order of the face indices (written back onto the input).
        mesh.mesh_f = mesh.mesh_f[:, ::-1]
        return trimesh.Trimesh(mesh.mesh_v, mesh.mesh_f)

    if not isinstance(mesh_output, list):
        return _convert(mesh_output)
    return [None if mesh is None else _convert(mesh) for mesh in mesh_output]
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
def get_obj_from_str(string, reload=False):
    """Resolve a dotted path such as ``"pkg.module.Name"`` to the named attribute.

    Args:
        string: dotted path; everything before the last dot is imported as a module.
        reload: when true, the module is reloaded before the attribute is looked up.
    """
    module_name, attr_name = string.rsplit(".", 1)
    if reload:
        importlib.reload(importlib.import_module(module_name))
    resolved_module = importlib.import_module(module_name, package=None)
    return getattr(resolved_module, attr_name)
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
def instantiate_from_config(config, **kwargs):
    """Instantiate the class named by ``config["target"]``.

    The constructor receives ``kwargs`` updated with ``config["params"]`` (if
    present), so values from ``params`` win on conflicting keys.

    Raises:
        KeyError: if ``config`` has no ``"target"`` entry.
    """
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    target_cls = get_obj_from_str(config["target"])
    kwargs.update(config.get("params", dict()))
    return target_cls(**kwargs)
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
class DiTPipeline:
|
| 138 |
-
model_cpu_offload_seq = "conditioner->model->vae"
|
| 139 |
-
_exclude_from_cpu_offload = []
|
| 140 |
-
|
| 141 |
-
@classmethod
@synchronize_timer('DiTPipeline Model Loading')
def from_single_file(
    cls,
    ckpt_path,
    config_path,
    device='cuda',
    dtype=torch.float16,
    use_safetensors=None,
    **kwargs,
):
    """Build a pipeline from one checkpoint file and one YAML config file.

    Args:
        ckpt_path: checkpoint path; with ``use_safetensors`` the ``.ckpt``
            substring is replaced by ``.safetensors`` first.
        config_path: YAML file with ``model``, ``vae``, ``conditioner``,
            ``image_processor`` and ``scheduler`` sections.
        device: forwarded to the pipeline constructor.
        dtype: forwarded to the pipeline constructor.
        use_safetensors: load the checkpoint with ``safetensors`` instead of ``torch.load``.
        **kwargs: extra constructor arguments; they override the ones built here.

    Raises:
        FileNotFoundError: if the (possibly renamed) checkpoint path does not exist.
    """
    # load config
    with open(config_path, 'r') as f:
        config = yaml.safe_load(f)

    # load ckpt
    if use_safetensors:
        ckpt_path = ckpt_path.replace('.ckpt', '.safetensors')
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"Model file {ckpt_path} not found")
    logger.info(f"Loading model from {ckpt_path}")

    if use_safetensors:
        # parse safetensors: regroup the flat "<component>.<param>" keys into
        # one state dict per component.
        import safetensors.torch
        flat_ckpt = safetensors.torch.load_file(ckpt_path, device='cpu')
        ckpt = {}
        for key, value in flat_ckpt.items():
            component, _, param_key = key.partition('.')
            ckpt.setdefault(component, {})[param_key] = value
    else:
        ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=True)

    # load model
    model = instantiate_from_config(config['model'])
    model.load_state_dict(ckpt['model'])
    vae = instantiate_from_config(config['vae'])
    vae.load_state_dict(ckpt['vae'], strict=False)
    conditioner = instantiate_from_config(config['conditioner'])
    if 'conditioner' in ckpt:
        conditioner.load_state_dict(ckpt['conditioner'])
    image_processor = instantiate_from_config(config['image_processor'])
    scheduler = instantiate_from_config(config['scheduler'])

    init_kwargs = {
        'vae': vae,
        'model': model,
        'scheduler': scheduler,
        'conditioner': conditioner,
        'image_processor': image_processor,
        'device': device,
        'dtype': dtype,
    }
    init_kwargs.update(kwargs)

    return cls(**init_kwargs)
|
| 201 |
-
|
| 202 |
-
@classmethod
def from_pretrained(
    cls,
    model_path,
    device='cuda',
    dtype=torch.float16,
    use_safetensors=False,
    variant='fp16',
    subfolder='hunyuan3d-dit-v2-1',
    **kwargs,
):
    """Resolve config and checkpoint paths via ``smart_load_model`` and build the pipeline.

    The arguments of this call are recorded under
    ``kwargs['from_pretrained_kwargs']`` and travel with the other extra
    keyword arguments into ``from_single_file``.
    """
    kwargs['from_pretrained_kwargs'] = {
        'model_path': model_path,
        'subfolder': subfolder,
        'use_safetensors': use_safetensors,
        'variant': variant,
        'dtype': dtype,
        'device': device,
    }
    config_path, ckpt_path = smart_load_model(
        model_path,
        subfolder=subfolder,
        use_safetensors=use_safetensors,
        variant=variant,
    )
    return cls.from_single_file(
        ckpt_path,
        config_path,
        device=device,
        dtype=dtype,
        use_safetensors=use_safetensors,
        **kwargs,
    )
|
| 235 |
-
|
| 236 |
-
def __init__(
    self,
    vae,
    model,
    scheduler,
    conditioner,
    image_processor,
    device='cuda',
    dtype=torch.float16,
    ref_model=None,
    **kwargs
):
    """Store the pipeline components and move them to ``device`` / ``dtype``.

    ``ref_model`` is optional and only registered in ``self.components`` when
    given; any extra keyword arguments are kept on ``self.kwargs``.
    """
    self.vae = vae
    self.model = model
    self.ref_model = ref_model
    self.scheduler = scheduler
    self.conditioner = conditioner
    self.image_processor = image_processor
    self.kwargs = kwargs

    # Name -> component registry (read by `_execution_device`).
    components = dict(
        vae=vae,
        model=model,
        scheduler=scheduler,
        conditioner=conditioner,
        image_processor=image_processor,
    )
    if ref_model is not None:
        components["ref_model"] = ref_model
    self.components = components

    self.to(device, dtype)
|
| 267 |
-
|
| 268 |
-
def compile(self):
    """Replace the VAE, the model and the conditioner with their ``torch.compile`` versions."""
    for attr in ("vae", "model", "conditioner"):
        setattr(self, attr, torch.compile(getattr(self, attr)))
|
| 272 |
-
|
| 273 |
-
def enable_flashvdm(
    self,
    enabled: bool = True,
    adaptive_kv_selection=True,
    topk_mode='mean',
    mc_algo='mc',
    replace_vae=True,
):
    """Switch the VAE's FlashVDM decoder on or off.

    Args:
        enabled: When falsy, the decoder is disabled and every other option
            is ignored.
        adaptive_kv_selection: Forwarded to ``vae.enable_flashvdm_decoder``.
        topk_mode: Forwarded to ``vae.enable_flashvdm_decoder``.
        mc_algo: Forwarded to ``vae.enable_flashvdm_decoder``.
        replace_vae: Accepted for interface compatibility; not used here.
    """
    if not enabled:
        self.vae.enable_flashvdm_decoder(enabled=False)
        return

    decoder_options = dict(
        enabled=enabled,
        adaptive_kv_selection=adaptive_kv_selection,
        topk_mode=topk_mode,
        mc_algo=mc_algo,
    )
    self.vae.enable_flashvdm_decoder(**decoder_options)
|
| 290 |
-
|
| 291 |
-
def to(self, device=None, dtype=None):
    """Cast and/or move the VAE, model and conditioner.

    Args:
        device: If given, recorded on ``self.device`` (as a ``torch.device``)
            and passed to each module's ``.to``.
        dtype: If given, recorded on ``self.dtype`` and passed to each
            module's ``.to`` as ``dtype=``.
    """
    if dtype is not None:
        self.dtype = dtype
        for module in (self.vae, self.model, self.conditioner):
            module.to(dtype=dtype)

    if device is not None:
        self.device = torch.device(device)
        for module in (self.vae, self.model, self.conditioner):
            module.to(device)
|
| 302 |
-
|
| 303 |
-
@property
def _execution_device(self):
    r"""
    Device on which the pipeline's models run.

    Walks ``self.components``; non-module entries and names listed in
    ``self._exclude_from_cpu_offload`` are skipped. The first remaining module
    without an ``_hf_hook`` attribute short-circuits to ``self.device``.
    Otherwise the first sub-module hook exposing a non-``None``
    ``execution_device`` decides the answer. Falls back to ``self.device``.
    """
    for component_name, component in self.components.items():
        is_candidate = (
            isinstance(component, torch.nn.Module)
            and component_name not in self._exclude_from_cpu_offload
        )
        if not is_candidate:
            continue

        if not hasattr(component, "_hf_hook"):
            return self.device

        for submodule in component.modules():
            hook = getattr(submodule, "_hf_hook", None)
            target = getattr(hook, "execution_device", None)
            if target is not None:
                return torch.device(target)

    return self.device
|
| 324 |
-
|
| 325 |
-
def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
    r"""
    Offload all models to CPU using accelerate, moving one whole model at a time to the
    accelerator when its `forward` is called.

    Models named in `self.model_cpu_offload_seq` (a `"a->b->c"` string) are chained so each
    hook knows its predecessor. Remaining module components are hooked independently,
    except names in `self._exclude_from_cpu_offload`, which are moved to the device directly.

    Arguments:
        gpu_id (`int`, *optional*):
            The ID of the accelerator that shall be used in inference. If not specified, a
            previously stored id is reused, otherwise 0.
        device (`torch.Device` or `str`, *optional*, defaults to "cuda"):
            The PyTorch device type of the accelerator that shall be used in inference.

    Raises:
        ValueError: If no `model_cpu_offload_seq` is set, or if both `gpu_id` and a device
            index are given.
        ImportError: If accelerate >= 0.17.0 is not available.
    """
    if self.model_cpu_offload_seq is None:
        raise ValueError(
            "Model CPU offload cannot be enabled because no `model_cpu_offload_seq` class attribute is set."
        )

    if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
        from accelerate import cpu_offload_with_hook
    else:
        raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")

    torch_device = torch.device(device)
    device_index = torch_device.index

    if gpu_id is not None and device_index is not None:
        raise ValueError(
            f"You have passed both `gpu_id`={gpu_id} and an index as part of the passed device `device`={device}"
            f"Cannot pass both. Please make sure to either not define `gpu_id` or not pass the index as part of "
            f"the device: `device`={torch_device.type}"
        )

    # _offload_gpu_id should be set to passed gpu_id (or id in passed `device`)
    # or default to previously set id or default to 0
    self._offload_gpu_id = gpu_id or torch_device.index or getattr(self, "_offload_gpu_id", 0)

    device_type = torch_device.type
    device = torch.device(f"{device_type}:{self._offload_gpu_id}")

    if self.device.type != "cpu":
        # Capture the backend name BEFORE moving: `self.to("cpu")` changes what
        # `self.device` reports, so looking it up afterwards would resolve the
        # CPU backend and the accelerator cache would never be emptied.
        previous_device_type = self.device.type
        self.to("cpu")
        device_mod = getattr(torch, previous_device_type, None)
        if hasattr(device_mod, "empty_cache") and device_mod.is_available():
            # otherwise we don't see the memory savings (but they probably exist)
            device_mod.empty_cache()

    all_model_components = {k: v for k, v in self.components.items() if isinstance(v, torch.nn.Module)}

    self._all_hooks = []
    hook = None
    for model_str in self.model_cpu_offload_seq.split("->"):
        model = all_model_components.pop(model_str, None)
        if not isinstance(model, torch.nn.Module):
            continue

        # Chain each hook to the previous one so the predecessor is offloaded
        # when this model starts running.
        _, hook = cpu_offload_with_hook(model, device, prev_module_hook=hook)
        self._all_hooks.append(hook)

    # CPU offload models that are not in the seq chain unless they are explicitly excluded
    # these models will stay on CPU until maybe_free_model_hooks is called
    # some models cannot be in the seq chain because they are iteratively called,
    # such as controlnet
    for name, model in all_model_components.items():
        if not isinstance(model, torch.nn.Module):
            continue

        if name in self._exclude_from_cpu_offload:
            model.to(device)
        else:
            _, hook = cpu_offload_with_hook(model, device)
            self._all_hooks.append(hook)
|
| 398 |
-
|
| 399 |
-
def maybe_free_model_hooks(self):
    r"""
    Offload every hooked component, drop the hooks, then re-apply
    `enable_model_cpu_offload` so the pipeline is back in its pre-call state.

    Does nothing when no hooks were ever registered (i.e. `enable_model_cpu_offload`
    has not been called). Intended to be invoked at the end of `__call__`.
    """
    registered_hooks = getattr(self, "_all_hooks", ())
    if len(registered_hooks) == 0:
        # Offloading was never enabled, so there is nothing to reset.
        return

    for registered in registered_hooks:
        # Push the model back to CPU, then detach the hook from it.
        registered.offload()
        registered.remove()

    # Re-install fresh hooks so the next call starts from the same state.
    self.enable_model_cpu_offload()
|
| 417 |
-
|
| 418 |
-
@synchronize_timer('Encode cond')
def encode_cond(self, image, additional_cond_inputs, do_classifier_free_guidance, dual_guidance):
    """Encode the conditioning and, for CFG, append the unconditional branches.

    Args:
        image: Passed to ``self.conditioner`` as ``image=``; its first
            dimension is used as the batch size.
        additional_cond_inputs: Extra keyword inputs for the conditioner.
            NOTE: under classifier-free guidance a ``"num_tokens"`` entry is
            written into this dict in place.
        do_classifier_free_guidance: When false, the plain conditional
            embedding is returned.
        dual_guidance: When true, three branches are stacked (conditional,
            unconditional-with-conditional-``'additional'``, unconditional);
            otherwise two (conditional, unconditional).

    Returns:
        The conditioner output, with tensors concatenated along dim 0 and
        cast to ``self.dtype`` when guidance is active.
    """
    batch = image.shape[0]
    cond = self.conditioner(image=image, **additional_cond_inputs)  # cond['main'].shape

    if not do_classifier_free_guidance:
        return cond

    def stack_branches(*branches):
        # Recursively concatenate matching tensors of (possibly nested) dicts.
        reference = branches[0]
        if isinstance(reference, torch.Tensor):
            return torch.cat(list(branches), dim=0).to(self.dtype)
        return {
            key: stack_branches(*(branch[key] for branch in branches))
            for key in reference.keys()
        }

    # The unconditional embedding is asked for as many tokens as the conditional one has.
    additional_cond_inputs["num_tokens"] = cond["main"].shape[1]
    un_cond = self.conditioner.unconditional_embedding(batch, **additional_cond_inputs)

    if not dual_guidance:
        return stack_branches(cond, un_cond)

    un_cond_drop_main = copy.deepcopy(un_cond)
    un_cond_drop_main['additional'] = cond['additional']
    return stack_branches(cond, un_cond_drop_main, un_cond)
|
| 452 |
-
|
| 453 |
-
def prepare_extra_step_kwargs(self, generator, eta):
    """Build the optional keyword arguments accepted by ``self.scheduler.step``.

    Schedulers differ in signature, so ``eta`` and ``generator`` are only
    included when ``step`` declares a parameter of that name. ``eta`` is the
    DDIM η (https://arxiv.org/abs/2010.02502) and should lie in [0, 1].

    Returns:
        dict: Subset of ``{"eta": eta, "generator": generator}``.
    """
    step_parameters = set(inspect.signature(self.scheduler.step).parameters.keys())
    candidates = (("eta", eta), ("generator", generator))
    return {name: value for name, value in candidates if name in step_parameters}
|
| 469 |
-
|
| 470 |
-
def prepare_latents(self, batch_size, dtype, device, generator, latents=None, shape=None):
    """Return the starting latents, scaled for the scheduler.

    Args:
        batch_size: Leading dimension of the latent tensor.
        dtype: dtype used when fresh noise is sampled.
        device: Device for sampled noise, or the target for supplied latents.
        generator: Random generator, or a list with one generator per sample.
        latents: Optional pre-made latents; moved to ``device`` instead of
            sampling.
        shape: Per-sample shape; defaults to ``self.vae.latent_shape``.

    Raises:
        ValueError: If ``generator`` is a list whose length differs from
            ``batch_size``.
    """
    per_sample_shape = self.vae.latent_shape if shape is None else shape
    shape = (batch_size, *per_sample_shape)

    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is not None:
        latents = latents.to(device)
    else:
        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

    # scale the initial noise by the standard deviation required by the scheduler
    noise_sigma = getattr(self.scheduler, 'init_noise_sigma', 1.0)
    return latents * noise_sigma
|
| 489 |
-
|
| 490 |
-
def prepare_image(self, image, mask=None) -> dict:
    """Turn the user-supplied image(s) into a dict of conditioner inputs.

    When both ``image`` and ``mask`` are already tensors they are returned
    as-is under ``'image'`` / ``'mask'``. Otherwise every image is run through
    ``self.image_processor`` and the per-image dicts are merged key by key;
    tensor values are concatenated along dim 0, other values stay as lists.
    ``mask`` is not used on that second path.

    Raises:
        FileNotFoundError: If ``image`` is a string that is not an existing path.
    """
    if isinstance(image, torch.Tensor) and isinstance(mask, torch.Tensor):
        return {'image': image, 'mask': mask}

    if isinstance(image, str) and not os.path.exists(image):
        raise FileNotFoundError(f"Couldn't find image at path {image}")

    images = image if isinstance(image, list) else [image]
    processed = [self.image_processor(single) for single in images]  # output['image'].shape

    # Group values by key, using the first result to fix the key set.
    grouped = {key: [] for key in processed[0].keys()}
    for result in processed:
        for key, value in result.items():
            grouped[key].append(value)

    return {
        key: torch.cat(values, dim=0) if isinstance(values[0], torch.Tensor) else values
        for key, values in grouped.items()
    }
|
| 518 |
-
|
| 519 |
-
def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
    """
    Sinusoidal embedding of guidance-scale values.

    See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

    Args:
        w (`torch.Tensor`):
            1-D tensor of guidance values to embed (multiplied by 1000 internally)
        embedding_dim (`int`, *optional*, defaults to 512):
            dimension of the embeddings to generate
        dtype:
            data type of the generated embeddings

    Returns:
        `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`, laid out as
        `[sin | cos]` (plus one trailing zero column when `embedding_dim` is odd), on the
        same device as `w`.
    """
    assert len(w.shape) == 1
    w = w * 1000.0

    half_dim = embedding_dim // 2
    emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
    # Build the frequency table on w's device so a non-CPU `w` does not hit a
    # cross-device multiplication below.
    emb = torch.exp(torch.arange(half_dim, dtype=dtype, device=w.device) * -emb)
    emb = w.to(dtype)[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1))
    assert emb.shape == (w.shape[0], embedding_dim)
    return emb
|
| 546 |
-
|
| 547 |
-
def set_surface_extractor(self, mc_algo):
    """Install the surface extractor registered under ``mc_algo`` on the VAE.

    A ``None`` value is a no-op. Any other value logs a deprecation notice
    and then replaces ``self.vae.surface_extractor`` with a new instance.

    Raises:
        ValueError: If ``mc_algo`` is not a key of ``SurfaceExtractors``.
    """
    if mc_algo is None:
        return

    deprecation_notice = (
        'The parameters `mc_algo` is deprecated, and will be removed in future versions.\n'
        'Please use: \n'
        'from hy3dshape.models.autoencoders import SurfaceExtractors\n'
        'pipeline.vae.surface_extractor = SurfaceExtractors[mc_algo]() instead\n'
    )
    logger.info(deprecation_notice)

    known_algorithms = SurfaceExtractors.keys()
    if mc_algo not in known_algorithms:
        raise ValueError(f"Unknown mc_algo {mc_algo}")

    extractor_cls = SurfaceExtractors[mc_algo]
    self.vae.surface_extractor = extractor_cls()
|
| 557 |
-
|
| 558 |
-
@torch.no_grad()
def __call__(
    self,
    image: Union[str, List[str], Image.Image] = None,
    num_inference_steps: int = 50,
    timesteps: List[int] = None,
    sigmas: List[float] = None,
    eta: float = 0.0,
    guidance_scale: float = 7.5,
    dual_guidance_scale: float = 10.5,
    dual_guidance: bool = True,
    generator=None,
    box_v=1.01,
    octree_resolution=384,
    mc_level=-1 / 512,
    num_chunks=8000,
    mc_algo=None,
    output_type: Optional[str] = "trimesh",
    enable_pbar=True,
    **kwargs,
) -> List[List[trimesh.Trimesh]]:
    """Run image-conditioned diffusion sampling and decode the result.

    Args:
        image: Input image(s); anything accepted by ``prepare_image``, or an
            already prepared ``torch.Tensor`` (used as-is, with no additional
            conditioning inputs).
        num_inference_steps: Number of sampling steps requested from the scheduler.
        timesteps: Optional explicit timesteps forwarded to ``retrieve_timesteps``.
        sigmas: Optional explicit sigmas forwarded to ``retrieve_timesteps``.
        eta: DDIM η, only used if the scheduler's ``step`` accepts it.
        guidance_scale: Classifier-free guidance weight. CFG is active when this
            is >= 0 and the model has no ``guidance_cond_proj_dim``.
        dual_guidance_scale: Weight of the second guidance term.
        dual_guidance: Enables the three-branch guidance formula (only when
            ``dual_guidance_scale`` >= 0).
        generator: Random generator(s) for the initial noise / scheduler.
        box_v, octree_resolution, mc_level, num_chunks, mc_algo: Mesh
            extraction options forwarded to ``_export``.
        output_type: ``"latent"`` returns raw latents; ``"trimesh"`` converts
            the decoded result through ``export_to_trimesh``.
        enable_pbar: Show progress bars.
        **kwargs: ``callback`` (called as ``callback(step_idx, t, outputs)``)
            and ``callback_steps`` (call interval, defaults to every step).

    Returns:
        Whatever ``_export`` produces for ``output_type``.
    """
    callback = kwargs.pop("callback", None)
    callback_steps = kwargs.pop("callback_steps", None)
    if callback_steps is None:
        # A callback without an explicit interval fires on every step
        # (previously `i % None` raised TypeError).
        callback_steps = 1

    self.set_surface_extractor(mc_algo)

    device = self.device
    dtype = self.dtype
    guidance_cond_proj_dim = getattr(self.model, 'guidance_cond_proj_dim', None)
    do_classifier_free_guidance = guidance_scale >= 0 and guidance_cond_proj_dim is None
    dual_guidance = dual_guidance_scale >= 0 and dual_guidance

    if isinstance(image, torch.Tensor):
        # Pre-processed tensor: there are no extra conditioner inputs
        # (previously `cond_inputs` was left undefined on this path).
        cond_inputs = {}
    else:
        cond_inputs = self.prepare_image(image)
        image = cond_inputs.pop('image')

    # NOTE(review): the conditioning is always built with dual_guidance=False
    # (two CFG branches) while the loop below stacks three latent copies when
    # `dual_guidance` is true — confirm which of the two is intended.
    cond = self.encode_cond(
        image=image,
        additional_cond_inputs=cond_inputs,
        do_classifier_free_guidance=do_classifier_free_guidance,
        dual_guidance=False,
    )
    batch_size = image.shape[0]

    t_dtype = torch.long
    timesteps, num_inference_steps = retrieve_timesteps(
        self.scheduler, num_inference_steps, device, timesteps, sigmas)

    latents = self.prepare_latents(batch_size, dtype, device, generator)
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    guidance_cond = None
    if guidance_cond_proj_dim is not None:
        # Guidance-distilled model: the scale is fed in as an embedding
        # instead of running separate conditional/unconditional passes.
        logger.info('Using lcm guidance scale')
        guidance_scale_tensor = torch.tensor(guidance_scale - 1).repeat(batch_size)
        guidance_cond = self.get_guidance_scale_embedding(
            guidance_scale_tensor, embedding_dim=self.model.guidance_cond_proj_dim
        ).to(device=device, dtype=latents.dtype)

    with synchronize_timer('Diffusion Sampling'):
        for i, t in enumerate(tqdm(timesteps, disable=not enable_pbar, desc="Diffusion Sampling:", leave=False)):
            # expand the latents if we are doing classifier free guidance
            if do_classifier_free_guidance:
                latent_model_input = torch.cat([latents] * (3 if dual_guidance else 2))
            else:
                latent_model_input = latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            timestep_tensor = torch.tensor([t], dtype=t_dtype, device=device)
            timestep_tensor = timestep_tensor.expand(latent_model_input.shape[0])
            noise_pred = self.model(latent_model_input, timestep_tensor, cond, guidance_cond=guidance_cond)

            # no drop, drop clip, all drop
            if do_classifier_free_guidance:
                if dual_guidance:
                    noise_pred_clip, noise_pred_dino, noise_pred_uncond = noise_pred.chunk(3)
                    noise_pred = (
                        noise_pred_uncond
                        + guidance_scale * (noise_pred_clip - noise_pred_dino)
                        + dual_guidance_scale * (noise_pred_dino - noise_pred_uncond)
                    )
                else:
                    noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            outputs = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)
            latents = outputs.prev_sample

            if callback is not None and i % callback_steps == 0:
                step_idx = i // getattr(self.scheduler, "order", 1)
                callback(step_idx, t, outputs)

    return self._export(
        latents,
        output_type,
        box_v, mc_level, num_chunks, octree_resolution, mc_algo,
        enable_pbar=enable_pbar,
    )
|
| 658 |
-
|
| 659 |
-
def _export(
    self,
    latents,
    output_type='trimesh',
    box_v=1.01,
    mc_level=0.0,
    num_chunks=20000,
    octree_resolution=256,
    mc_algo='mc',
    enable_pbar=True
):
    """Decode sampled latents into the requested output representation.

    Args:
        latents: Latents produced by the sampling loop.
        output_type: ``"latent"`` returns ``latents`` untouched; any other
            value decodes through the VAE, and ``"trimesh"`` additionally
            passes the result through ``export_to_trimesh``.
        box_v: Forwarded to ``vae.latents2mesh`` as ``bounds``.
        mc_level, num_chunks, octree_resolution, mc_algo, enable_pbar:
            Forwarded to ``vae.latents2mesh`` under the same names.
    """
    if output_type == "latent":
        return latents

    # Undo the VAE scale factor before decoding.
    unscaled = 1. / self.vae.scale_factor * latents
    decoded = self.vae(unscaled)
    extraction_options = dict(
        bounds=box_v,
        mc_level=mc_level,
        num_chunks=num_chunks,
        octree_resolution=octree_resolution,
        mc_algo=mc_algo,
        enable_pbar=enable_pbar,
    )
    outputs, _ = self.vae.latents2mesh(decoded, **extraction_options)

    if output_type == 'trimesh':
        outputs = export_to_trimesh(outputs)

    return outputs
|
| 689 |
-
|
| 690 |
-
|
| 691 |
-
class UltraShapePipeline(DiTPipeline):
    """Flow-style sampling pipeline with optional voxel conditioning."""

    @torch.inference_mode()
    def __call__(
        self,
        image: Union[str, List[str], Image.Image, dict, List[dict], torch.Tensor] = None,
        voxel_cond: torch.Tensor = None,
        num_inference_steps: int = 50,
        timesteps: List[int] = None,
        sigmas: List[float] = None,
        eta: float = 0.0,
        guidance_scale: float = 5.0,
        generator=None,
        box_v=1.01,
        octree_resolution=384,
        mc_level=0.0,
        mc_algo=None,
        num_chunks=8000,
        output_type: Optional[str] = "trimesh",
        enable_pbar=True,
        mask=None,
        **kwargs,
    ) -> List[List[trimesh.Trimesh]]:
        """Sample latents conditioned on ``image`` (and ``voxel_cond``) and decode them.

        Args:
            image: Input image(s), forwarded with ``mask`` to ``prepare_image``.
            voxel_cond: Optional voxel conditioning; its second dimension sets
                the number of latent tokens. Duplicated along the batch axis
                under classifier-free guidance.
                (assumed layout ``[B, N, 3]`` — confirm against the model)
            num_inference_steps: Number of sampling steps.
            timesteps: Accepted for interface compatibility; not forwarded to
                ``retrieve_timesteps`` here.
            sigmas: Sigma schedule; defaults to ``np.linspace(0, 1, num_inference_steps)``.
            eta: Accepted for interface compatibility; not used here.
            guidance_scale: CFG weight. CFG is active when this is >= 0 and the
                model does not have ``guidance_embed is True``; with guidance
                embedding the scale is passed to the model as a tensor instead.
            generator: Random generator(s) for the initial noise.
            box_v, octree_resolution, mc_level, mc_algo, num_chunks: Mesh
                extraction options forwarded to ``_export``.
            output_type: Output representation, see ``_export``.
            enable_pbar: Show progress bars.
            mask: Optional mask forwarded to ``prepare_image``.
            **kwargs: ``callback`` (called as ``callback(step_idx, t, outputs)``)
                and ``callback_steps`` (call interval, defaults to every step).

        Returns:
            A ``(exported, latents)`` tuple: the ``_export`` result and the final
            latents. NOTE: this is wider than the declared return annotation.
        """
        callback = kwargs.pop("callback", None)
        callback_steps = kwargs.pop("callback_steps", None)
        if callback_steps is None:
            # A callback without an explicit interval fires on every step
            # (previously `i % None` raised TypeError).
            callback_steps = 1

        self.set_surface_extractor(mc_algo)

        device = self.device
        dtype = self.dtype
        do_classifier_free_guidance = guidance_scale >= 0 and not (
            hasattr(self.model, 'guidance_embed') and
            self.model.guidance_embed is True
        )

        cond_inputs = self.prepare_image(image, mask)
        image = cond_inputs.pop('image')
        cond = self.encode_cond(
            image=image,
            additional_cond_inputs=cond_inputs,
            do_classifier_free_guidance=do_classifier_free_guidance,
            dual_guidance=False,
        )

        batch_size = image.shape[0]

        # 5. Prepare timesteps
        # NOTE: this is slightly different from common usage, we start from 0.
        sigmas = np.linspace(0, 1, num_inference_steps) if sigmas is None else sigmas
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler,
            num_inference_steps,
            device,
            sigmas=sigmas,
        )

        latents_shape = None
        if voxel_cond is not None:
            # One latent token per voxel entry; channel width comes from the VAE.
            num_tokens = voxel_cond.shape[1]
            latents_shape = (num_tokens, self.vae.latent_shape[-1])

        latents = self.prepare_latents(batch_size, dtype, device, generator, shape=latents_shape)

        guidance = None
        if hasattr(self.model, 'guidance_embed') and \
                self.model.guidance_embed is True:
            guidance = torch.tensor([guidance_scale] * batch_size, device=device, dtype=dtype)

        if do_classifier_free_guidance and voxel_cond is not None:
            # Match the doubled (cond + uncond) latent batch.
            voxel_cond = torch.cat([voxel_cond] * 2)

        with synchronize_timer('Diffusion Sampling'):
            for i, t in enumerate(tqdm(timesteps, disable=not enable_pbar, desc="Diffusion Sampling:")):
                # expand the latents if we are doing classifier free guidance
                if do_classifier_free_guidance:
                    latent_model_input = torch.cat([latents] * 2)
                else:
                    latent_model_input = latents

                # NOTE: we assume model get timesteps ranged from 0 to 1
                timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype).to(latent_model_input.device)
                timestep = timestep / self.scheduler.config.num_train_timesteps
                if voxel_cond is None:
                    noise_pred = self.model(latent_model_input, timestep, cond, guidance=guidance)
                else:
                    noise_pred = self.model(latent_model_input, timestep, cond,
                                            guidance=guidance, voxel_cond=voxel_cond)

                if do_classifier_free_guidance:
                    noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                outputs = self.scheduler.step(noise_pred, t, latents)
                latents = outputs.prev_sample

                if callback is not None and i % callback_steps == 0:
                    step_idx = i // getattr(self.scheduler, "order", 1)
                    callback(step_idx, t, outputs)

        return self._export(
            latents,
            output_type,
            box_v, mc_level, num_chunks, octree_resolution, mc_algo,
            enable_pbar=enable_pbar,
        ), latents
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ultrashape/postprocessors.py
DELETED
|
@@ -1,209 +0,0 @@
|
|
| 1 |
-
# ==============================================================================
|
| 2 |
-
# Original work Copyright (c) 2025 Tencent.
|
| 3 |
-
# Modified work Copyright (c) 2025 UltraShape Team.
|
| 4 |
-
#
|
| 5 |
-
# Modified by UltraShape on 2025.12.25
|
| 6 |
-
# ==============================================================================
|
| 7 |
-
|
| 8 |
-
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
| 9 |
-
# except for the third-party components listed below.
|
| 10 |
-
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
| 11 |
-
# in the respective licenses of these third-party components.
|
| 12 |
-
# Users must comply with all terms and conditions of original licenses of these third-party
|
| 13 |
-
# components and must ensure that the usage of the third party components adheres to
|
| 14 |
-
# all relevant laws and regulations.
|
| 15 |
-
|
| 16 |
-
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
| 17 |
-
# their software and algorithms, including trained model weights, parameters (including
|
| 18 |
-
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
| 19 |
-
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
| 20 |
-
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
| 21 |
-
|
| 22 |
-
import os
|
| 23 |
-
import tempfile
|
| 24 |
-
from typing import Union
|
| 25 |
-
|
| 26 |
-
import numpy as np
|
| 27 |
-
import pymeshlab
|
| 28 |
-
import torch
|
| 29 |
-
import trimesh
|
| 30 |
-
|
| 31 |
-
from .models.autoencoders import Latent2MeshOutput
|
| 32 |
-
from .utils import synchronize_timer
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
def load_mesh(path):
    """Load a mesh file from ``path``.

    ``.glb`` files are read with ``trimesh.load``; every other path is loaded
    into a fresh ``pymeshlab.MeshSet``.
    """
    if path.endswith(".glb"):
        return trimesh.load(path)

    mesh_set = pymeshlab.MeshSet()
    mesh_set.load_new_mesh(path)
    return mesh_set
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
def reduce_face(mesh: pymeshlab.MeshSet, max_facenum: int = 200000):
    """Decimate the current mesh down to ``max_facenum`` faces.

    The mesh set is returned unchanged when its current mesh already has
    fewer than ``max_facenum`` faces; otherwise quadric edge-collapse
    decimation is applied in place and the same mesh set is returned.
    """
    current_face_count = mesh.current_mesh().face_number()
    if current_face_count < max_facenum:
        return mesh

    decimation_options = dict(
        targetfacenum=max_facenum,
        qualitythr=1.0,
        preserveboundary=True,
        boundaryweight=3,
        preservenormal=True,
        preservetopology=True,
        autoclean=True,
    )
    mesh.apply_filter("meshing_decimation_quadric_edge_collapse", **decimation_options)
    return mesh
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
def remove_floater(mesh: pymeshlab.MeshSet):
    """Delete small disconnected components from the mesh set, in place.

    Runs three filters in order: select faces of small components, transfer
    that selection to vertices, then remove the selected vertices and faces.
    """
    filter_pipeline = (
        ("compute_selection_by_small_disconnected_components_per_face", {"nbfaceratio": 0.005}),
        ("compute_selection_transfer_face_to_vertex", {"inclusive": False}),
        ("meshing_remove_selected_vertices_and_faces", {}),
    )
    for filter_name, filter_options in filter_pipeline:
        mesh.apply_filter(filter_name, **filter_options)
    return mesh
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
def pymeshlab2trimesh(mesh: pymeshlab.MeshSet):
    """Convert the current mesh of a pymeshlab ``MeshSet`` into a trimesh object.

    The mesh is round-tripped through a temporary ``.ply`` file, which is
    removed again once it has been loaded.

    Returns:
        The object produced by ``trimesh.load``; when that is a
        ``trimesh.Scene`` its geometries are concatenated into a single mesh.
    """
    # Reserve a path and close the handle right away so pymeshlab can write to
    # it by name (an open handle blocks that on some platforms).
    handle, ply_path = tempfile.mkstemp(suffix='.ply')
    os.close(handle)
    try:
        mesh.save_current_mesh(ply_path)
        loaded = trimesh.load(ply_path)
    finally:
        # Previously the file was created with delete=False and never removed.
        os.remove(ply_path)

    # Check what kind of object was loaded.
    if isinstance(loaded, trimesh.Scene):
        combined_mesh = trimesh.Trimesh()
        # For a Scene, walk every geometry and merge them into one mesh.
        for geom in loaded.geometry.values():
            combined_mesh = trimesh.util.concatenate([combined_mesh, geom])
        loaded = combined_mesh
    return loaded
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
def trimesh2pymeshlab(mesh: trimesh.Trimesh):
    """Convert a ``trimesh`` mesh (or scene) into a ``pymeshlab.MeshSet``.

    A ``trimesh.scene.Scene`` is first merged into one mesh by adding its
    geometries together. The mesh is then exported to a temporary ``.ply``
    file, loaded into a new MeshSet, and the temporary file is removed.

    Args:
        mesh: A ``trimesh.Trimesh`` or ``trimesh.scene.Scene``.

    Returns:
        A new ``pymeshlab.MeshSet`` containing the mesh.

    Raises:
        ValueError: If ``mesh`` is a Scene without any geometry.
    """
    if isinstance(mesh, trimesh.scene.Scene):
        geometries = list(mesh.geometry.values())
        if not geometries:
            # Previously this surfaced as an UnboundLocalError further down.
            raise ValueError('cannot convert an empty trimesh Scene')
        merged = geometries[0]
        for geom in geometries[1:]:
            merged = merged + geom
        mesh = merged

    # Reserve a unique path; the handle is closed before other libraries
    # write to / read from it.
    with tempfile.NamedTemporaryFile(suffix='.ply', delete=False) as temp_file:
        temp_path = temp_file.name
    try:
        mesh.export(temp_path)
        mesh_set = pymeshlab.MeshSet()
        mesh_set.load_new_mesh(temp_path)
    finally:
        # delete=False means nobody else cleans this file up.
        os.remove(temp_path)
    return mesh_set
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
def export_mesh(input, output):
    """Convert a processed MeshSet back to the kind of object the caller gave.

    Args:
        input: The object originally passed in by the caller. Only its type
            is inspected.
        output: The processed ``pymeshlab.MeshSet``.

    Returns:
        * ``output`` itself when ``input`` is a ``pymeshlab.MeshSet``;
        * a new ``Latent2MeshOutput`` whose ``mesh_v`` / ``mesh_f`` hold the
          vertex and face matrices of ``output``'s current mesh when
          ``input`` is a ``Latent2MeshOutput``;
        * otherwise the result of ``pymeshlab2trimesh(output)``.
    """
    if isinstance(input, pymeshlab.MeshSet):
        return output

    if isinstance(input, Latent2MeshOutput):
        # Read the matrices from the MeshSet *before* building the result;
        # the result object must not shadow the MeshSet we read from.
        current = output.current_mesh()
        converted = Latent2MeshOutput()
        converted.mesh_v = current.vertex_matrix()
        converted.mesh_f = current.face_matrix()
        return converted

    return pymeshlab2trimesh(output)
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
def import_mesh(mesh: Union[pymeshlab.MeshSet, trimesh.Trimesh, Latent2MeshOutput, str]) -> pymeshlab.MeshSet:
    """Normalize any supported mesh representation to a ``pymeshlab.MeshSet``.

    Args:
        mesh: One of
            * ``str`` - a path, loaded through ``load_mesh``;
            * ``Latent2MeshOutput`` - its ``mesh_v`` / ``mesh_f`` arrays are
              wrapped in a new MeshSet;
            * ``trimesh.Trimesh`` / ``trimesh.scene.Scene`` - converted with
              ``trimesh2pymeshlab``;
            * ``pymeshlab.MeshSet`` - returned unchanged.

    Returns:
        The mesh as a ``pymeshlab.MeshSet``.
    """
    if isinstance(mesh, str):
        mesh = load_mesh(mesh)
    elif isinstance(mesh, Latent2MeshOutput):
        # Build the MeshSet under a separate name so the vertex/face arrays
        # can still be read from the original Latent2MeshOutput.
        mesh_set = pymeshlab.MeshSet()
        mesh_pymeshlab = pymeshlab.Mesh(vertex_matrix=mesh.mesh_v, face_matrix=mesh.mesh_f)
        mesh_set.add_mesh(mesh_pymeshlab, "converted_mesh")
        mesh = mesh_set

    # Not an elif: load_mesh() returns a trimesh object for .glb files,
    # which still needs converting.
    if isinstance(mesh, (trimesh.Trimesh, trimesh.scene.Scene)):
        mesh = trimesh2pymeshlab(mesh)

    return mesh
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
class FaceReducer:
    """Callable post-processor that limits the number of faces of a mesh."""

    @synchronize_timer('FaceReducer')
    def __call__(
        self,
        mesh: Union[pymeshlab.MeshSet, trimesh.Trimesh, Latent2MeshOutput, str],
        max_facenum: int = 40000
    ) -> Union[pymeshlab.MeshSet, trimesh.Trimesh]:
        """Decimate ``mesh`` towards ``max_facenum`` faces.

        The input is converted with ``import_mesh``, simplified with
        ``reduce_face`` and converted back with ``export_mesh``.
        """
        reduced = reduce_face(import_mesh(mesh), max_facenum=max_facenum)
        return export_mesh(mesh, reduced)
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
class FloaterRemover:
    """Callable post-processor that strips small disconnected components."""

    @synchronize_timer('FloaterRemover')
    def __call__(
        self,
        mesh: Union[pymeshlab.MeshSet, trimesh.Trimesh, Latent2MeshOutput, str],
    ) -> Union[pymeshlab.MeshSet, trimesh.Trimesh, Latent2MeshOutput]:
        """Run ``remove_floater`` on ``mesh``.

        The input is converted with ``import_mesh`` and the cleaned result is
        converted back with ``export_mesh``.
        """
        cleaned = remove_floater(import_mesh(mesh))
        return export_mesh(mesh, cleaned)
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
class DegenerateFaceRemover:
    """Callable post-processor that round-trips a mesh through a PLY file.

    NOTE(review): presumably the PLY save/load is what drops degenerate
    faces; no explicit cleaning filter is applied here - confirm.
    """

    @synchronize_timer('DegenerateFaceRemover')
    def __call__(
        self,
        mesh: Union[pymeshlab.MeshSet, trimesh.Trimesh, Latent2MeshOutput, str],
    ) -> Union[pymeshlab.MeshSet, trimesh.Trimesh, Latent2MeshOutput]:
        """Save ``mesh`` to a temporary ``.ply`` file and load it back.

        Args:
            mesh: Any representation accepted by ``import_mesh``.

        Returns:
            The reloaded mesh, converted back with ``export_mesh``.
        """
        ms = import_mesh(mesh)

        # Reserve a unique path and close the handle before pymeshlab uses it.
        with tempfile.NamedTemporaryFile(suffix='.ply', delete=False) as temp_file:
            temp_path = temp_file.name
        try:
            ms.save_current_mesh(temp_path)
            ms = pymeshlab.MeshSet()
            ms.load_new_mesh(temp_path)
        finally:
            # delete=False means nobody else cleans this file up.
            os.remove(temp_path)

        mesh = export_mesh(mesh, ms)
        return mesh
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
def mesh_normalize(mesh):
    """
    Normalize mesh vertices to sphere

    The vertices are translated so that the center of their axis-aligned
    bounding box sits at the origin, then scaled uniformly so that the
    farthest vertex lies at distance ``scale_factor / 2`` (0.6) from it.

    Args:
        mesh: Object with a ``vertices`` attribute convertible to an
            ``(N, 3)`` array; it is modified in place.

    Returns:
        The same ``mesh`` object. A mesh without vertices is returned
        unchanged; a mesh whose vertices all coincide is only centered.
    """
    scale_factor = 1.2
    vtx_pos = np.asarray(mesh.vertices)
    if vtx_pos.size == 0:
        return mesh

    # Per-axis bounding box. (On a NumPy array `.max(0)[0]` would return only
    # the x-axis extreme, so no trailing index here.)
    max_bb = vtx_pos.max(0)
    min_bb = vtx_pos.min(0)
    center = (max_bb + min_bb) / 2

    centered = vtx_pos - center

    # Twice the largest distance from the center.
    scale = float(torch.norm(torch.tensor(centered, dtype=torch.float32), dim=1).max() * 2.0)

    if scale > 0:
        centered = centered * (scale_factor / scale)
    mesh.vertices = centered

    return mesh
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
class MeshSimplifier:
    """Simplify a mesh by running an external executable on OBJ files."""

    def __init__(self, executable: str = None):
        """
        Args:
            executable: Path of the simplifier binary. Defaults to
                ``mesh_simplifier.bin`` located next to this module.
        """
        if executable is None:
            CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
            executable = os.path.join(CURRENT_DIR, "mesh_simplifier.bin")
        self.executable = executable

    @synchronize_timer('MeshSimplifier')
    def __call__(
        self,
        mesh: Union[trimesh.Trimesh],
    ) -> Union[trimesh.Trimesh]:
        """Export ``mesh``, run the simplifier on it and load the result.

        Args:
            mesh: Mesh to simplify; it is exported to a temporary ``.obj``.

        Returns:
            The simplified mesh after ``mesh_normalize``. When the output
            loads as a ``trimesh.Scene`` its geometries are concatenated
            into a single mesh first.
        """
        # Function-scope import: keeps this block self-contained.
        import subprocess

        temp_paths = []
        try:
            # Reserve two unique paths (input, output); the handles are
            # closed immediately so other programs can open the files.
            for _ in range(2):
                with tempfile.NamedTemporaryFile(suffix='.obj', delete=False) as handle:
                    temp_paths.append(handle.name)
            input_path, output_path = temp_paths

            mesh.export(input_path)
            # Argument list, no shell: paths containing spaces or shell
            # metacharacters are passed through verbatim.
            # NOTE(review): the exit status is still ignored, as before; a
            # failed run surfaces when loading the output below.
            subprocess.run([self.executable, input_path, output_path], check=False)
            ms = trimesh.load(output_path, process=False)
        finally:
            # delete=False means nobody else cleans these files up.
            for temp_path in temp_paths:
                try:
                    os.remove(temp_path)
                except OSError:
                    pass

        if isinstance(ms, trimesh.Scene):
            combined_mesh = trimesh.Trimesh()
            for geom in ms.geometry.values():
                combined_mesh = trimesh.util.concatenate([combined_mesh, geom])
            ms = combined_mesh
        ms = mesh_normalize(ms)
        return ms
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ultrashape/preprocessors.py
DELETED
|
@@ -1,167 +0,0 @@
|
|
| 1 |
-
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
| 2 |
-
# except for the third-party components listed below.
|
| 3 |
-
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
| 4 |
-
# in the respective licenses of these third-party components.
|
| 5 |
-
# Users must comply with all terms and conditions of original licenses of these third-party
|
| 6 |
-
# components and must ensure that the usage of the third party components adheres to
|
| 7 |
-
# all relevant laws and regulations.
|
| 8 |
-
|
| 9 |
-
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
| 10 |
-
# their software and algorithms, including trained model weights, parameters (including
|
| 11 |
-
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
| 12 |
-
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
| 13 |
-
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
| 14 |
-
|
| 15 |
-
import cv2
|
| 16 |
-
import numpy as np
|
| 17 |
-
import torch
|
| 18 |
-
from PIL import Image
|
| 19 |
-
from einops import repeat, rearrange
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
def array_to_tensor(np_array):
    """Turn an ``h w c`` image array into a batched, rescaled float tensor.

    Values are mapped with ``x / 255 * 2 - 1``, the axes are rearranged from
    ``h w c`` to ``c h w`` and a leading batch axis of size 1 is added.

    Args:
        np_array: Image array laid out as ``h w c``.

    Returns:
        Float tensor laid out as ``b c h w`` with ``b == 1``.
    """
    scaled = torch.tensor(np_array).float() / 255 * 2 - 1
    channels_first = rearrange(scaled, "h w c -> c h w")
    return repeat(channels_first, "c h w -> b c h w", b=1)
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
class ImageProcessorV2:
    """Load an image, recenter its foreground and resize it to a square."""

    def __init__(self, size=512, border_ratio=None):
        """
        Args:
            size: Side length (pixels) of the square output image and mask.
            border_ratio: When not ``None``, overrides the ``border_ratio``
                passed to ``__call__``.
        """
        self.size = size
        self.border_ratio = border_ratio

    @staticmethod
    def recenter(image, border_ratio: float = 0.2):
        """ recenter an image to leave some empty space at the image border.

        The bounding box of the non-zero alpha region is cropped, scaled so
        that its longer side is ``(1 - border_ratio)`` of the canvas, pasted
        in the middle of a square canvas and composited over white.

        Args:
            image (ndarray): input image, [H, W, 3/4]. Without an alpha
                channel a fully opaque one is added.
            border_ratio (float, optional): border ratio, image will be resized to (1 - border_ratio). Defaults to 0.2.

        Returns:
            tuple: ``(result, mask)`` where ``result`` is the uint8 image
            composited over a white background, [S, S, 3], and ``mask`` is
            the uint8 alpha mask, [S, S, 1], with ``S = max(H, W)``.

        Raises:
            ValueError: If the alpha mask is entirely zero or the foreground
                bounding box has zero height or width.
        """

        if image.shape[-1] == 4:
            mask = image[..., 3]
        else:
            mask = np.ones_like(image[..., 0:1]) * 255
            image = np.concatenate([image, mask], axis=-1)
            mask = mask[..., 0]

        H, W, C = image.shape

        size = max(H, W)
        result = np.zeros((size, size, C), dtype=np.uint8)

        coords = np.nonzero(mask)
        if coords[0].size == 0:
            # Without this guard .min() fails with an unrelated NumPy
            # reduction error before the intended check below is reached.
            raise ValueError('input image is empty')
        # NOTE: "x" indexes rows and "y" indexes columns here.
        x_min, x_max = coords[0].min(), coords[0].max()
        y_min, y_max = coords[1].min(), coords[1].max()
        h = x_max - x_min
        w = y_max - y_min
        if h == 0 or w == 0:
            raise ValueError('input image is empty')
        desired_size = int(size * (1 - border_ratio))
        scale = desired_size / max(h, w)
        h2 = int(h * scale)
        w2 = int(w * scale)
        x2_min = (size - h2) // 2
        x2_max = x2_min + h2

        y2_min = (size - w2) // 2
        y2_max = y2_min + w2

        result[x2_min:x2_max, y2_min:y2_max] = cv2.resize(image[x_min:x_max, y_min:y_max], (w2, h2),
                                                          interpolation=cv2.INTER_AREA)

        bg = np.ones((result.shape[0], result.shape[1], 3), dtype=np.uint8) * 255

        # Alpha-composite the color channels over the white background.
        mask = result[..., 3:].astype(np.float32) / 255
        result = result[..., :3] * mask + bg * (1 - mask)

        mask = mask * 255
        result = result.clip(0, 255).astype(np.uint8)
        mask = mask.clip(0, 255).astype(np.uint8)
        return result, mask

    def load_image(self, image, border_ratio=0.15, to_tensor=True):
        """Read, recenter and resize one image.

        Args:
            image: A file path (read with OpenCV) or a ``PIL.Image.Image``.
            border_ratio: Forwarded to ``recenter``.
            to_tensor: When true, image and mask are converted with
                ``array_to_tensor``.

        Returns:
            tuple: ``(image, mask)`` resized to ``self.size`` x ``self.size``.

        Raises:
            ValueError: If a path cannot be read as an image, or the image
                has no foreground (see ``recenter``).
            TypeError: If ``image`` is neither a path nor a PIL image.
        """
        if isinstance(image, str):
            path = image
            image = cv2.imread(path, cv2.IMREAD_UNCHANGED)
            if image is None:
                # cv2.imread signals failure by returning None.
                raise ValueError(f'failed to read image: {path}')
            image, mask = self.recenter(image, border_ratio=border_ratio)
            # OpenCV loads BGR; the rest of the pipeline uses RGB.
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        elif isinstance(image, Image.Image):
            image = image.convert("RGBA")
            image = np.asarray(image)
            image, mask = self.recenter(image, border_ratio=border_ratio)
        else:
            # Previously this fell through to an UnboundLocalError on `mask`.
            raise TypeError(
                f'unsupported image type: {type(image).__name__}; expected a file path or a PIL image'
            )

        image = cv2.resize(image, (self.size, self.size), interpolation=cv2.INTER_CUBIC)
        mask = cv2.resize(mask, (self.size, self.size), interpolation=cv2.INTER_NEAREST)
        mask = mask[..., np.newaxis]

        if to_tensor:
            image = array_to_tensor(image)
            mask = array_to_tensor(mask)
        return image, mask

    def __call__(self, image, border_ratio=0.15, to_tensor=True, **kwargs):
        """Process one image.

        Args:
            image: A file path or ``PIL.Image.Image``.
            border_ratio: Used unless ``self.border_ratio`` is set.
            to_tensor: Forwarded to ``load_image``.
            **kwargs: Accepted and ignored.

        Returns:
            dict: ``{'image': ..., 'mask': ...}``.
        """
        if self.border_ratio is not None:
            border_ratio = self.border_ratio
        image, mask = self.load_image(image, border_ratio=border_ratio, to_tensor=to_tensor)
        outputs = {
            'image': image,
            'mask': mask
        }
        return outputs
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
class MVImageProcessorV2(ImageProcessorV2):
    """
    view order: front, front clockwise 90, back, front clockwise 270
    """
    return_view_idx = True

    def __init__(self, size=512, border_ratio=None):
        """
        Args:
            size: Side length (pixels) of each processed view.
            border_ratio: When not ``None``, overrides the ``border_ratio``
                passed to ``__call__``.
        """
        super().__init__(size, border_ratio)
        # Maps a view tag to its position in the output.
        self.view2idx = {
            'front': 0,
            'left': 1,
            'back': 2,
            'right': 3
        }

    def __call__(self, image_dict, border_ratio=0.15, to_tensor=True, **kwargs):
        """Process several tagged views and stack them in view order.

        Args:
            image_dict: Mapping from view tag (a key of ``self.view2idx``)
                to an image accepted by ``load_image``.
            border_ratio: Used unless ``self.border_ratio`` is set.
            to_tensor: Forwarded to ``load_image``.
                NOTE(review): the results are combined with ``torch.cat``,
                so this presumably has to stay ``True`` - confirm.
            **kwargs: Accepted and ignored.

        Returns:
            dict: ``'image'`` and ``'mask'`` (views concatenated along
            dim 0, then given a leading axis of size 1) and ``'view_idxs'``
            (tuple of view indices in ascending order).

        Raises:
            ValueError: If ``image_dict`` is empty.
            KeyError: If a view tag is not in ``self.view2idx``.
        """
        if self.border_ratio is not None:
            border_ratio = self.border_ratio

        if not image_dict:
            # Otherwise the unpacking below fails with a cryptic message.
            raise ValueError('image_dict must contain at least one view')

        views = []
        for view_tag, image in image_dict.items():
            view_idx = self.view2idx[view_tag]
            image, mask = self.load_image(image, border_ratio=border_ratio, to_tensor=to_tensor)
            views.append((view_idx, image, mask))

        # Order by view index only, so tensors are never compared.
        views.sort(key=lambda view: view[0])
        view_idxs, images, masks = zip(*views)

        image = torch.cat(images, 0).unsqueeze(0)
        mask = torch.cat(masks, 0).unsqueeze(0)
        outputs = {
            'image': image,
            'mask': mask,
            'view_idxs': view_idxs
        }
        return outputs
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
# Registry mapping a processor key to its image-processor class.
IMAGE_PROCESSORS = {
    "v2": ImageProcessorV2,
    'mv_v2': MVImageProcessorV2,
}

# Key into IMAGE_PROCESSORS; presumably the processor used when none is
# requested explicitly - TODO confirm against callers.
DEFAULT_IMAGEPROCESSOR = 'v2'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ultrashape/rembg.py
DELETED
|
@@ -1,32 +0,0 @@
|
|
| 1 |
-
# ==============================================================================
|
| 2 |
-
# Original work Copyright (c) 2025 Tencent.
|
| 3 |
-
# Modified work Copyright (c) 2025 UltraShape Team.
|
| 4 |
-
#
|
| 5 |
-
# Modified by UltraShape on 2025.12.25
|
| 6 |
-
# ==============================================================================
|
| 7 |
-
|
| 8 |
-
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
| 9 |
-
# except for the third-party components listed below.
|
| 10 |
-
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
| 11 |
-
# in the respective licenses of these third-party components.
|
| 12 |
-
# Users must comply with all terms and conditions of original licenses of these third-party
|
| 13 |
-
# components and must ensure that the usage of the third party components adheres to
|
| 14 |
-
# all relevant laws and regulations.
|
| 15 |
-
|
| 16 |
-
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
| 17 |
-
# their software and algorithms, including trained model weights, parameters (including
|
| 18 |
-
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
| 19 |
-
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
| 20 |
-
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
| 21 |
-
|
| 22 |
-
from PIL import Image
|
| 23 |
-
from rembg import remove, new_session
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
class BackgroundRemover():
    """Callable wrapper around ``rembg`` that reuses a single session."""

    def __init__(self):
        # Created once and reused for every call.
        self.session = new_session()

    def __call__(self, image: Image.Image):
        """Run ``rembg.remove`` on ``image`` with the stored session.

        The background color passed to ``rembg`` is ``[255, 255, 255, 0]``.
        """
        return remove(image, session=self.session, bgcolor=[255, 255, 255, 0])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ultrashape/schedulers.py
DELETED
|
@@ -1,480 +0,0 @@
|
|
| 1 |
-
# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
|
| 2 |
-
#
|
| 3 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
-
# you may not use this file except in compliance with the License.
|
| 5 |
-
# You may obtain a copy of the License at
|
| 6 |
-
#
|
| 7 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
-
#
|
| 9 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
-
# See the License for the specific language governing permissions and
|
| 13 |
-
# limitations under the License.
|
| 14 |
-
|
| 15 |
-
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
| 16 |
-
# except for the third-party components listed below.
|
| 17 |
-
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
| 18 |
-
# in the respective licenses of these third-party components.
|
| 19 |
-
# Users must comply with all terms and conditions of original licenses of these third-party
|
| 20 |
-
# components and must ensure that the usage of the third party components adheres to
|
| 21 |
-
# all relevant laws and regulations.
|
| 22 |
-
|
| 23 |
-
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
| 24 |
-
# their software and algorithms, including trained model weights, parameters (including
|
| 25 |
-
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
| 26 |
-
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
| 27 |
-
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
| 28 |
-
|
| 29 |
-
import math
|
| 30 |
-
from dataclasses import dataclass
|
| 31 |
-
from typing import List, Optional, Tuple, Union
|
| 32 |
-
|
| 33 |
-
import numpy as np
|
| 34 |
-
import torch
|
| 35 |
-
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 36 |
-
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
| 37 |
-
from diffusers.utils import BaseOutput, logging
|
| 38 |
-
|
| 39 |
-
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
@dataclass
class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
    """
    Result container returned by the scheduler's `step` method.

    Args:
        prev_sample (`torch.FloatTensor`):
            Sample produced by the step; feed it to the model as the input of the next denoising iteration.
    """

    prev_sample: torch.FloatTensor
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """
    NOTE: this is very similar to diffusers.FlowMatchEulerDiscreteScheduler. Except our timesteps are reversed

    Euler scheduler.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        shift (`float`, defaults to 1.0):
            The shift value for the timestep schedule.
        use_dynamic_shifting (`bool`, defaults to `False`):
            When `True`, the static `shift` is not applied; `set_timesteps` then requires `mu` and shifts the
            sigmas with `time_shift` instead.
    """

    _compatibles = []
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        shift: float = 1.0,
        use_dynamic_shifting=False,
    ):
        timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32).copy()
        timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)

        sigmas = timesteps / num_train_timesteps
        if not use_dynamic_shifting:
            # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
            sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)

        self.timesteps = sigmas * num_train_timesteps

        self._step_index = None
        self._begin_index = None

        self.sigmas = sigmas.to("cpu")  # to avoid too much CPU/GPU communication
        # The schedule is ascending (reversed w.r.t. diffusers), so `sigma_min` holds the last (largest) entry
        # and `sigma_max` the first (smallest) one. `set_timesteps` relies on exactly this assignment.
        self.sigma_min = self.sigmas[-1].item()
        self.sigma_max = self.sigmas[0].item()

    @property
    def step_index(self):
        """
        The index counter for current timestep. It will increase 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def scale_noise(
        self,
        sample: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        noise: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        """
        Forward process in flow-matching

        Args:
            sample (`torch.FloatTensor`):
                The input sample.
            timestep (`torch.FloatTensor`):
                The current timestep(s) in the diffusion chain, one entry per batch element.
            noise (`torch.FloatTensor`):
                The noise to blend into `sample`. Must be provided.

        Returns:
            `torch.FloatTensor`:
                `sigma * noise + (1 - sigma) * sample`.

        Raises:
            ValueError: If `noise` is `None`.
        """
        if noise is None:
            # The default only exists for signature compatibility; the blend below needs real noise.
            raise ValueError("`noise` must be provided to `scale_noise`")

        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        sigmas = self.sigmas.to(device=sample.device, dtype=sample.dtype)

        if sample.device.type == "mps" and torch.is_floating_point(timestep):
            # mps does not support float64
            schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32)
            timestep = timestep.to(sample.device, dtype=torch.float32)
        else:
            schedule_timesteps = self.timesteps.to(sample.device)
            timestep = timestep.to(sample.device)

        # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
        if self.begin_index is None:
            step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timestep]
        elif self.step_index is not None:
            # add_noise is called after first denoising step (for inpainting)
            step_indices = [self.step_index] * timestep.shape[0]
        else:
            # add noise is called before first denoising step to create initial latent(img2img)
            step_indices = [self.begin_index] * timestep.shape[0]

        sigma = sigmas[step_indices].flatten()
        # Append trailing axes so sigma broadcasts against sample.
        while len(sigma.shape) < len(sample.shape):
            sigma = sigma.unsqueeze(-1)

        sample = sigma * noise + (1.0 - sigma) * sample

        return sample

    def _sigma_to_t(self, sigma):
        """Convert a sigma to a timestep by scaling with `num_train_timesteps`."""
        return sigma * self.config.num_train_timesteps

    def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
        """Return `exp(mu) / (exp(mu) + (1 / t - 1) ** sigma)`."""
        return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

    def set_timesteps(
        self,
        num_inference_steps: int = None,
        device: Union[str, torch.device] = None,
        sigmas: Optional[List[float]] = None,
        mu: Optional[float] = None,
    ):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model. Required
                unless `sigmas` is given.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
            sigmas (`List[float]`, *optional*):
                Custom sigmas (before shifting). When given, `num_inference_steps` is taken from their length.
            mu (`float`, *optional*):
                Shift parameter; required when `use_dynamic_shifting` is `True`.

        Raises:
            ValueError: If `mu` is missing while dynamic shifting is enabled, or if neither
                `num_inference_steps` nor `sigmas` is given.
        """

        if self.config.use_dynamic_shifting and mu is None:
            raise ValueError("you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`")

        if sigmas is None:
            if num_inference_steps is None:
                raise ValueError("either `num_inference_steps` or `sigmas` must be provided")
            timesteps = np.linspace(
                self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
            )

            sigmas = timesteps / self.config.num_train_timesteps
        else:
            # Accept plain Python sequences as the annotation promises; ndarrays pass through unchanged.
            sigmas = np.asarray(sigmas)
            num_inference_steps = len(sigmas)
        self.num_inference_steps = num_inference_steps

        if self.config.use_dynamic_shifting:
            sigmas = self.time_shift(mu, 1.0, sigmas)
        else:
            sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)

        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
        timesteps = sigmas * self.config.num_train_timesteps

        self.timesteps = timesteps.to(device=device)
        # Terminal sigma of 1.0 so that `step` can always read `sigmas[step_index + 1]`.
        self.sigmas = torch.cat([sigmas, torch.ones(1, device=sigmas.device)])

        self._step_index = None
        self._begin_index = None

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        """Return the index of `timestep` in `schedule_timesteps` (defaults to `self.timesteps`)."""
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        pos = 1 if len(indices) > 1 else 0

        return indices[pos].item()

    def _init_step_index(self, timestep):
        """Initialise `_step_index` from `timestep`, or from `begin_index` when that has been set."""
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        s_churn: float = 0.0,
        s_tmin: float = 0.0,
        s_tmax: float = float("inf"),
        s_noise: float = 1.0,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
        """
        Advance the sample by one Euler step: `sample + (sigma_next - sigma) * model_output`.

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            s_churn (`float`):
                Unused; kept for interface compatibility.
            s_tmin (`float`):
                Unused; kept for interface compatibility.
            s_tmax (`float`):
                Unused; kept for interface compatibility.
            s_noise (`float`, defaults to 1.0):
                Unused; kept for interface compatibility.
            generator (`torch.Generator`, *optional*):
                Unused; kept for interface compatibility.
            return_dict (`bool`):
                Whether or not to return a [`FlowMatchEulerDiscreteSchedulerOutput`] or tuple.

        Returns:
            [`FlowMatchEulerDiscreteSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`FlowMatchEulerDiscreteSchedulerOutput`] is
                returned, otherwise a tuple is returned where the first element is the sample tensor.

        Raises:
            ValueError: If `timestep` is an integer or an integer tensor.
        """

        if (
            isinstance(timestep, int)
            or isinstance(timestep, torch.IntTensor)
            or isinstance(timestep, torch.LongTensor)
        ):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.to(torch.float32).to(model_output.device)

        sigma = self.sigmas[self.step_index].to(model_output.device)
        sigma_next = self.sigmas[self.step_index + 1].to(model_output.device)

        prev_sample = sample + (sigma_next - sigma) * model_output

        # Cast sample back to model compatible dtype
        prev_sample = prev_sample.to(model_output.dtype)

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)

    def __len__(self):
        return self.config.num_train_timesteps
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
@dataclass
class ConsistencyFlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
    """
    Output container with two tensor fields.

    Args:
        prev_sample (`torch.FloatTensor`):
            Presumably the sample to feed into the next step - confirm against the scheduler's `step`.
        pred_original_sample (`torch.FloatTensor`):
            Presumably the predicted fully denoised sample - confirm against the scheduler's `step`.
    """

    prev_sample: torch.FloatTensor
    pred_original_sample: torch.FloatTensor
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
class ConsistencyFlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """Euler flow-matching scheduler that steps on a coarse grid of `pcm_timesteps` sigmas.

    The constructor picks `pcm_timesteps` entries out of a linear sigma ramp
    `linspace(0, 1, num_train_timesteps)`. `set_timesteps` then sub-samples that
    coarse grid for inference, and `step` performs one explicit Euler update.
    """

    _compatibles = []
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        pcm_timesteps: int = 50,
    ):
        """Build the coarse sigma grid.

        Args:
            num_train_timesteps: Length of the underlying linear sigma ramp.
            pcm_timesteps: Number of grid points kept from that ramp.

        Raises:
            ValueError: If `pcm_timesteps` is not in `[1, num_train_timesteps]`.
        """
        # A ratio of zero (pcm_timesteps > num_train_timesteps) would make every
        # index below equal to -1 and silently produce a meaningless grid.
        if not 0 < pcm_timesteps <= num_train_timesteps:
            raise ValueError(
                "`pcm_timesteps` must be between 1 and `num_train_timesteps` "
                f"({num_train_timesteps}), got {pcm_timesteps}."
            )

        sigmas = np.linspace(0, 1, num_train_timesteps)
        step_ratio = num_train_timesteps // pcm_timesteps

        # Indices step_ratio - 1, 2 * step_ratio - 1, ... with index 0 prepended.
        euler_timesteps = (np.arange(1, pcm_timesteps) * step_ratio).round().astype(np.int64) - 1
        euler_timesteps = np.asarray([0] + euler_timesteps.tolist())

        self.euler_timesteps = euler_timesteps
        self.sigmas = sigmas[self.euler_timesteps]
        self.sigmas = torch.from_numpy((self.sigmas.copy())).to(dtype=torch.float32)
        self.timesteps = self.sigmas * num_train_timesteps
        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

    @property
    def step_index(self):
        """
        The index counter for current timestep. It will increase 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def _sigma_to_t(self, sigma):
        """Scale a sigma by `num_train_timesteps` to obtain its timestep value."""
        return sigma * self.config.num_train_timesteps

    def set_timesteps(
        self,
        num_inference_steps: int = None,
        device: Union[str, torch.device] = None,
        sigmas: Optional[List[float]] = None,
    ):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
            sigmas (`List[float]`, *optional*):
                Only its length is used, as a fallback step count when
                `num_inference_steps` is `None`; the values themselves are ignored.

        Raises:
            ValueError: If both `num_inference_steps` and `sigmas` are `None`.
        """
        if num_inference_steps is None:
            if sigmas is None:
                raise ValueError("Either `num_inference_steps` or `sigmas` must be provided.")
            num_inference_steps = len(sigmas)
        self.num_inference_steps = num_inference_steps

        # Evenly spaced (floored) positions on the coarse grid, last point excluded.
        inference_indices = np.linspace(
            0, self.config.pcm_timesteps, num=self.num_inference_steps, endpoint=False
        )
        inference_indices = np.floor(inference_indices).astype(np.int64)
        inference_indices = torch.from_numpy(inference_indices).long()

        self.sigmas_ = self.sigmas[inference_indices]
        timesteps = self.sigmas_ * self.config.num_train_timesteps
        self.timesteps = timesteps.to(device=device)
        # Terminal sigma of 1 so that `step` can always read `sigmas_[index + 1]`.
        self.sigmas_ = torch.cat(
            [self.sigmas_, torch.ones(1, device=self.sigmas_.device)]
        )

        self._step_index = None
        self._begin_index = None

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        """Return the position of `timestep` inside `schedule_timesteps` (default: `self.timesteps`)."""
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        pos = 1 if len(indices) > 1 else 0

        return indices[pos].item()

    def _init_step_index(self, timestep):
        """Initialise `_step_index` from `begin_index`, or by looking `timestep` up when it is unset."""
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[ConsistencyFlowMatchEulerDiscreteSchedulerOutput, Tuple]:
        """Advance `sample` by one Euler step along `model_output`.

        Args:
            model_output: Output of the model at the current step.
            timestep: One of the values in `self.timesteps`; integer indices are rejected.
            sample: Current sample.
            generator: Unused by this scheduler.
            return_dict: When `False`, return the one-element tuple `(prev_sample,)`.

        Returns:
            A `ConsistencyFlowMatchEulerDiscreteSchedulerOutput` holding
            `prev_sample` and `pred_original_sample`, or `(prev_sample,)`.

        Raises:
            ValueError: If `timestep` is an `int`, `torch.IntTensor` or `torch.LongTensor`.
        """
        if (
            isinstance(timestep, int)
            or isinstance(timestep, torch.IntTensor)
            or isinstance(timestep, torch.LongTensor)
        ):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    " `ConsistencyFlowMatchEulerDiscreteScheduler.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        # Do the update in float32, then cast back to the model dtype below.
        sample = sample.to(torch.float32).to(model_output.device)

        sigma = self.sigmas_[self.step_index].to(model_output.device)
        sigma_next = self.sigmas_[self.step_index + 1].to(model_output.device)

        prev_sample = sample + (sigma_next - sigma) * model_output
        prev_sample = prev_sample.to(model_output.dtype)

        # Same Euler update, but jumping straight to sigma = 1.
        pred_original_sample = sample + (1.0 - sigma) * model_output
        pred_original_sample = pred_original_sample.to(model_output.dtype)

        self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return ConsistencyFlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample,
                                                                pred_original_sample=pred_original_sample)

    def __len__(self):
        """Return the configured number of training timesteps."""
        return self.config.num_train_timesteps
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ultrashape/surface_loaders.py
DELETED
|
@@ -1,233 +0,0 @@
|
|
| 1 |
-
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
| 2 |
-
# except for the third-party components listed below.
|
| 3 |
-
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
| 4 |
-
# in the respective licenses of these third-party components.
|
| 5 |
-
# Users must comply with all terms and conditions of original licenses of these third-party
|
| 6 |
-
# components and must ensure that the usage of the third party components adheres to
|
| 7 |
-
# all relevant laws and regulations.
|
| 8 |
-
|
| 9 |
-
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
| 10 |
-
# their software and algorithms, including trained model weights, parameters (including
|
| 11 |
-
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
| 12 |
-
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
| 13 |
-
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
| 14 |
-
|
| 15 |
-
import numpy as np
|
| 16 |
-
import torch
|
| 17 |
-
import trimesh
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
def normalize_mesh(mesh, scale=0.9999):
    """
    Center a mesh on the origin and rescale it uniformly, in place.

    The bounding-box center is moved to the origin, then the mesh is scaled so
    that the longest bounding-box side has length `2 * scale`, i.e. spans
    [-scale, scale].

    Args:
        mesh (trimesh.Trimesh): Mesh to normalize; it is modified in place.
        scale (float, optional): Half-extent of the longest side after scaling. Default is 0.9999.

    Returns:
        trimesh.Trimesh: The same mesh object, after translation and scaling.
    """
    bounds = mesh.bounds
    lower, upper = bounds[0], bounds[1]
    midpoint = (upper + lower) / 2
    longest_side = (upper - lower).max()

    mesh.apply_translation(-midpoint)
    mesh.apply_scale(1 / longest_side * 2 * scale)
    return mesh
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
def sample_pointcloud(mesh, num=200000):
    """
    Draw surface samples from a mesh together with the normal of the face each sample lies on.

    Args:
        mesh (trimesh.Trimesh): Mesh to sample from.
        num (int, optional): Number of samples requested from `mesh.sample`. Default is 200000.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]:
            - points: float32 tensor with the sampled positions.
            - normals: float32 tensor with the face normal for each sample.
    """
    positions, face_ids = mesh.sample(num, return_index=True)
    per_sample_normals = mesh.face_normals[face_ids]
    return (
        torch.from_numpy(positions.astype(np.float32)),
        torch.from_numpy(per_sample_normals.astype(np.float32)),
    )
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
def load_surface(mesh, num_points=8192):
    """
    Normalize the mesh, sample points and normals from its surface, and pick a fixed random subset.

    Args:
        mesh (trimesh.Trimesh): Input mesh to process (normalized in place).
        num_points (int, optional): Number of points to select from the
            sampled surface points. Default is 8192.

    Returns:
        Tuple[torch.Tensor, trimesh.Trimesh]:
            - surface: Tensor of shape (1, num_points, 6), concatenating points and normals.
            - mesh: The normalized mesh.
    """
    mesh = normalize_mesh(mesh, scale=0.98)

    # Selection is without replacement, so the candidate pool must hold at least
    # `num_points` samples; the pool stays at 200000 for every smaller request.
    pool_size = max(200000, num_points)
    surface, normal = sample_pointcloud(mesh, num=pool_size)

    # Fixed seed: the same candidate pool always yields the same subset indices.
    rng = np.random.default_rng(seed=0)
    ind = rng.choice(surface.shape[0], num_points, replace=False)
    surface = torch.FloatTensor(surface[ind])
    normal = torch.FloatTensor(normal[ind])

    surface = torch.cat([surface, normal], dim=-1).unsqueeze(0)

    return surface, mesh
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
def sharp_sample_pointcloud(mesh, num=16384):
    """
    Sample points and normals along the sharp edges of a mesh.

    A vertex counts as sharp when the dot product between its vertex normal and
    the normal of a face it belongs to drops below 0.985. An edge is sharp when
    both of its endpoints are sharp. Edges are drawn with probability
    proportional to their length, and a point is placed uniformly along each.

    Args:
        mesh (trimesh.Trimesh): Input mesh to sample from.
        num (int, optional): Number of points to sample from sharp edges. Default is 16384.

    Returns:
        Tuple[np.ndarray, np.ndarray]:
            - samples: Sampled points along sharp edges, shape (num, 3).
            - normals: Vertex normals interpolated along the edge, shape (num, 3).

    Raises:
        IndexError: If the mesh has no sharp edges to sample from.
    """
    V = mesh.vertices
    N = mesh.face_normals
    VN = mesh.vertex_normals
    F = mesh.faces
    VN2 = np.ones(V.shape[0])
    for i in range(3):
        # NOTE(review): when a vertex is the i-th corner of several faces, the
        # fancy-index assignment keeps only one of the candidate values, so VN2
        # is not a strict minimum over all incident faces. Left unchanged to
        # keep the sampling behaviour - confirm whether that is intended.
        dot = np.stack((VN2[F[:, i]], np.sum(VN[F[:, i]] * N, axis=-1)), axis=-1)
        VN2[F[:, i]] = np.min(dot, axis=-1)

    sharp_mask = VN2 < 0.985
    # Collect the three directed edges of every face.
    edge_a = np.concatenate((F[:, 0], F[:, 1], F[:, 2]))
    edge_b = np.concatenate((F[:, 1], F[:, 2], F[:, 0]))
    sharp_edge = sharp_mask[edge_a] * sharp_mask[edge_b]
    edge_a = edge_a[sharp_edge > 0]
    edge_b = edge_b[sharp_edge > 0]

    sharp_verts_a = V[edge_a]
    sharp_verts_b = V[edge_b]
    sharp_verts_an = VN[edge_a]
    sharp_verts_bn = VN[edge_b]

    # Edge-length weights, normalized into a probability distribution.
    weights = np.linalg.norm(sharp_verts_b - sharp_verts_a, axis=-1)
    weights /= np.sum(weights)

    random_number = np.random.rand(num)
    w = np.random.rand(num, 1)
    index = np.searchsorted(weights.cumsum(), random_number)
    # Rounding can leave the last cumulative weight slightly below 1, in which
    # case a draw close to 1 would map one past the final edge; clamp it back.
    index = np.minimum(index, weights.shape[0] - 1)
    samples = w * sharp_verts_a[index] + (1 - w) * sharp_verts_b[index]
    normals = w * sharp_verts_an[index] + (1 - w) * sharp_verts_bn[index]
    return samples, normals
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
def load_surface_sharpegde(mesh, num_points=4096, num_sharp_points=4096, sharpedge_flag=True, normalize_scale=0.9999):
    """
    Build a surface point set mixing uniform surface samples with sharp-edge samples.

    Args:
        mesh: Mesh (or object exposing `dump()`) to process.
        num_points (int, optional): Number of uniform surface rows to keep. Default is 4096.
        num_sharp_points (int, optional): Number of sharp-edge rows to keep. Default is 4096.
        sharpedge_flag (bool, optional): When True, append a label column
            (0 for uniform rows, 1 for sharp-edge rows). Default is True.
        normalize_scale (float, optional): Scale forwarded to `normalize_mesh`. Default is 0.9999.

    Returns:
        Tuple[torch.Tensor, trimesh.Trimesh]:
            - Tensor with a leading batch axis of size 1 holding the uniform rows
              followed by the sharp-edge rows.
            - The concatenated, normalized mesh.
    """
    # Flatten whatever was passed into a single mesh; fall back to a direct
    # concatenate when `dump()` is unavailable or fails.
    try:
        merged = trimesh.util.concatenate(mesh.dump())
    except Exception:
        merged = trimesh.util.concatenate(mesh)
    mesh_full = normalize_mesh(merged, scale=normalize_scale)

    all_faces = mesh_full.faces
    all_vertices = mesh_full.vertices
    face_count = all_faces.shape[0]

    # `face_count` covers every face, so the "fill" mesh receives an empty face slice.
    mesh = trimesh.Trimesh(vertices=all_vertices, faces=all_faces[:face_count])
    mesh_fill = trimesh.Trimesh(vertices=all_vertices, faces=all_faces[face_count:])
    area = mesh.area
    area_fill = mesh_fill.area

    # Split the sampling budget between both meshes in proportion to area.
    sample_num = 819200 // 2  # 499712 // 2
    num_fill = int(sample_num * (area_fill / (area + area_fill)))
    num = sample_num - num_fill

    random_surface, random_normal = sample_pointcloud(mesh, num=num)
    if num_fill != 0:
        random_surface_fill, random_normal_fill = sample_pointcloud(mesh_fill, num=num_fill)
    else:
        random_surface_fill, random_normal_fill = np.zeros((0, 3)), np.zeros((0, 3))
    random_sharp_surface, sharp_normal = sharp_sample_pointcloud(mesh, num=sample_num)

    # Rows of [x, y, z, nx, ny, nz], stored as float16.
    uniform_rows = np.concatenate((random_surface, random_normal), axis=1).astype(np.float16)
    fill_rows = np.concatenate((random_surface_fill, random_normal_fill), axis=1).astype(np.float16)
    sharp_surface = np.concatenate((random_sharp_surface, sharp_normal), axis=1).astype(np.float16)
    surface = np.concatenate((uniform_rows, fill_rows), axis=0)

    if sharpedge_flag:
        surface = np.concatenate((surface, np.zeros((surface.shape[0], 1))), axis=1)
        sharp_surface = np.concatenate((sharp_surface, np.ones((sharp_surface.shape[0], 1))), axis=1)

    # Unseeded generator: the selected subset differs between calls.
    rng = np.random.default_rng()
    ind = rng.choice(surface.shape[0], num_points, replace=False)
    surface = torch.FloatTensor(surface[ind])
    ind = rng.choice(sharp_surface.shape[0], num_sharp_points, replace=False)
    sharp_surface = torch.FloatTensor(sharp_surface[ind])

    combined = torch.cat([surface, sharp_surface], dim=0).unsqueeze(0)
    return combined, mesh_full
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
class SurfaceLoader:
    """Callable that turns a mesh (or a path to one) into a sampled surface tensor."""

    def __init__(self, num_points=8192):
        """Store the default number of surface points returned per call."""
        self.num_points = num_points

    def __call__(self, mesh_or_mesh_path, num_points=None):
        """
        Load the mesh if needed, merge scene geometry, and sample its surface.

        Args:
            mesh_or_mesh_path: A mesh/scene object, or a path string handed to `trimesh.load`.
            num_points (int, optional): Overrides the default set in the constructor.

        Returns:
            torch.Tensor: The surface tensor produced by `load_surface`.

        Raises:
            ValueError: If a scene without any geometry is supplied.
        """
        if num_points is None:
            num_points = self.num_points

        mesh = mesh_or_mesh_path
        if isinstance(mesh, str):
            mesh = trimesh.load(mesh, force="mesh", merge_primitives=True)
        if isinstance(mesh, trimesh.scene.Scene):
            parts = list(mesh.geometry.values())
            # An empty scene previously fell through to an unbound local variable.
            if not parts:
                raise ValueError("Scene contains no geometry to sample a surface from.")
            merged = parts[0]
            for part in parts[1:]:
                merged = merged + part
            mesh = merged
        surface, _ = load_surface(mesh, num_points=num_points)
        return surface
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
class SharpEdgeSurfaceLoader:
    """Callable that samples a mesh surface with extra points along sharp edges."""

    def __init__(self, num_uniform_points=8192, num_sharp_points=8192, **kwargs):
        """
        Store the default sample counts.

        Extra keyword arguments are accepted and ignored.
        """
        self.num_uniform_points = num_uniform_points
        self.num_sharp_points = num_sharp_points
        self.num_points = num_uniform_points + num_sharp_points

    def __call__(self, mesh_or_mesh_path, num_uniform_points=None,
                 num_sharp_points=None, normalize_scale=0.9999):
        """
        Load the mesh if needed, merge scene geometry, and sample its surface.

        Args:
            mesh_or_mesh_path: A mesh/scene object, or a path string handed to `trimesh.load`.
            num_uniform_points (int, optional): Overrides the constructor default.
            num_sharp_points (int, optional): Overrides the constructor default.
            normalize_scale (float, optional): Forwarded to `load_surface_sharpegde`.

        Returns:
            torch.Tensor: The surface tensor produced by `load_surface_sharpegde`.

        Raises:
            ValueError: If a scene without any geometry is supplied.
        """
        if num_uniform_points is None:
            num_uniform_points = self.num_uniform_points
        if num_sharp_points is None:
            num_sharp_points = self.num_sharp_points

        mesh = mesh_or_mesh_path
        if isinstance(mesh, str):
            mesh = trimesh.load(mesh, force="mesh", merge_primitives=True)
        if isinstance(mesh, trimesh.scene.Scene):
            parts = list(mesh.geometry.values())
            # An empty scene previously fell through to an unbound local variable.
            if not parts:
                raise ValueError("Scene contains no geometry to sample a surface from.")
            merged = parts[0]
            for part in parts[1:]:
                merged = merged + part
            mesh = merged
        surface, _ = load_surface_sharpegde(mesh, num_points=num_uniform_points,
                                            num_sharp_points=num_sharp_points,
                                            normalize_scale=normalize_scale)
        return surface
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ultrashape/utils/__init__.py
DELETED
|
@@ -1,6 +0,0 @@
|
|
| 1 |
-
# -*- coding: utf-8 -*-
|
| 2 |
-
|
| 3 |
-
from .misc import get_config_from_file
|
| 4 |
-
from .misc import instantiate_from_config
|
| 5 |
-
from .utils import get_logger, logger, synchronize_timer, smart_load_model
|
| 6 |
-
from .voxelize import voxelize_from_point
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ultrashape/utils/ema.py
DELETED
|
@@ -1,76 +0,0 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
from torch import nn
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
class LitEma(nn.Module):
    """Exponential moving average (EMA) of a model's trainable parameters.

    Each trainable parameter is mirrored in a registered buffer. Calling the
    module with the model moves every buffer a fraction `1 - decay` of the way
    towards the current parameter value.
    """

    def __init__(self, model, decay=0.9999, use_num_updates=True):
        """
        Args:
            model: Module whose parameters with `requires_grad=True` are tracked.
            decay: Upper bound for the EMA decay, within [0, 1].
            use_num_updates: When True, the effective decay is warmed up as
                `min(decay, (1 + n) / (10 + n))`, with `n` the update count.

        Raises:
            ValueError: If `decay` lies outside [0, 1].
        """
        super().__init__()
        if decay < 0.0 or decay > 1.0:
            raise ValueError('Decay must be between 0 and 1')

        self.m_name2s_name = {}
        self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
        # A value of -1 marks "no warm-up"; `forward` only counts when it is >= 0.
        self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_updates
                             else torch.tensor(-1, dtype=torch.int))

        for name, p in model.named_parameters():
            if p.requires_grad:
                # remove as '.'-character is not allowed in buffers
                s_name = name.replace('.', '_____')
                self.m_name2s_name.update({name: s_name})
                self.register_buffer(s_name, p.clone().detach().data)

        self.collected_params = []

    def forward(self, model):
        """Update every shadow buffer in place from the current parameters of `model`."""
        decay = self.decay

        if self.num_updates >= 0:
            self.num_updates += 1
            decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))

        one_minus_decay = 1.0 - decay

        with torch.no_grad():
            m_param = dict(model.named_parameters())
            shadow_params = dict(self.named_buffers())

            for key in m_param:
                if m_param[key].requires_grad:
                    shadow = shadow_params[self.m_name2s_name[key]]
                    # Convert the parameter to the buffer's dtype/device rather
                    # than the other way round: converting the buffer creates a
                    # copy, and the in-place update would then never reach the
                    # registered buffer.
                    shadow.sub_(one_minus_decay * (shadow - m_param[key].to(shadow)))
                else:
                    assert key not in self.m_name2s_name

    def copy_to(self, model):
        """Overwrite the trainable parameters of `model` with their EMA values."""
        m_param = dict(model.named_parameters())
        shadow_params = dict(self.named_buffers())
        for key in m_param:
            if m_param[key].requires_grad:
                m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
            else:
                assert key not in self.m_name2s_name

    def store(self, model):
        """
        Save the current parameters for restoring later.
        Args:
            model: Module whose parameters are cloned and kept in
                `self.collected_params`.
        """
        self.collected_params = [param.clone() for param in model.parameters()]

    def restore(self, model):
        """
        Restore the parameters stored with the `store` method.
        Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters.
        Args:
            model: Module whose parameters are overwritten with the stored copies.
        """
        for c_param, param in zip(self.collected_params, model.parameters()):
            param.data.copy_(c_param.data)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ultrashape/utils/misc.py
DELETED
|
@@ -1,200 +0,0 @@
|
|
| 1 |
-
# -*- coding: utf-8 -*-
|
| 2 |
-
|
| 3 |
-
import importlib
|
| 4 |
-
from omegaconf import OmegaConf, DictConfig, ListConfig
|
| 5 |
-
|
| 6 |
-
import torch
|
| 7 |
-
import torch.distributed as dist
|
| 8 |
-
from typing import Union
|
| 9 |
-
from .utils import logger
|
| 10 |
-
import os
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
def get_config_from_file(config_file: str) -> Union[DictConfig, ListConfig]:
    """Load a YAML config, merging it on top of its `base_config` when one is declared.

    A top-level `base_config` key may be `"default_base"` (an empty base) or a
    path ending in `.yaml`, which is loaded recursively. The remaining keys of
    the file override the base.

    Args:
        config_file: Path to the YAML file.

    Returns:
        The loaded config, merged with its base if `base_config` is present.

    Raises:
        ValueError: If `base_config` is neither `"default_base"` nor a `.yaml` path.
    """
    config = OmegaConf.load(config_file)

    # Only mapping configs can carry a `base_config` key.
    if not isinstance(config, DictConfig) or 'base_config' not in config:
        return config

    base_name = config['base_config']
    if base_name == "default_base":
        base_config = OmegaConf.create()
    elif base_name.endswith(".yaml"):
        base_config = get_config_from_file(base_name)
    else:
        raise ValueError(
            f"`base_config` in {config_file} must be \"default_base\" or a path to a "
            f"`.yaml` file, got {base_name!r}."
        )

    # Drop the `base_config` key itself before merging. Iterating the config
    # directly yields keys only, so the keys are filtered explicitly.
    override_keys = [key for key in config.keys() if key != "base_config"]
    overrides = OmegaConf.masked_copy(config, override_keys)

    return OmegaConf.merge(base_config, overrides)
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
def get_obj_from_str(string, reload=False):
    """Resolve a dotted path such as `"package.module.Name"` to the object it names.

    Args:
        string: Dotted path; everything before the last dot is the module.
        reload: When True, reload the module before the attribute lookup.

    Returns:
        The attribute found on the imported module.
    """
    module_path, attr_name = string.rsplit(".", 1)
    if reload:
        importlib.reload(importlib.import_module(module_path))
    module_obj = importlib.import_module(module_path, package=None)
    return getattr(module_obj, attr_name)
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
def get_obj_from_config(config):
    """Return the object named by `config["target"]` without instantiating it.

    Raises:
        KeyError: If `config` has no `target` key.
    """
    if "target" in config:
        return get_obj_from_str(config["target"])
    raise KeyError("Expected key `target` to instantiate.")
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
def instantiate_from_config(config, **kwargs):
    """Instantiate the class named by `config["target"]`.

    When `config["from_pretrained"]` is truthy, the class's `from_pretrained`
    is called with it (plus `use_safetensors` and `variant` from the config)
    and `kwargs`/`params` are not used. Otherwise the class is constructed from
    `kwargs` updated with `config["params"]`, so `params` wins on conflicts.

    Raises:
        KeyError: If `config` has no `target` key.
    """
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")

    target_cls = get_obj_from_str(config["target"])

    if not config.get("from_pretrained", None):
        ctor_kwargs = kwargs
        ctor_kwargs.update(config.get("params", dict()))
        return target_cls(**ctor_kwargs)

    return target_cls.from_pretrained(
        config["from_pretrained"],
        use_safetensors=config.get('use_safetensors', False),
        variant=config.get('variant', 'fp16'),
    )
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
def instantiate_vae_from_config(config, **kwargs):
    """Instantiate the VAE class named by `config["target"]`.

    When `config["from_pretrained"]` is truthy, the class's `from_pretrained`
    receives it together with `params`, `use_safetensors` and `variant` from
    the config. Otherwise the class is constructed from `kwargs` updated with
    `config["params"]`, so `params` wins on conflicts.

    Raises:
        KeyError: If `config` has no `target` key.
    """
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")

    vae_cls = get_obj_from_str(config["target"])

    if not config.get("from_pretrained", None):
        ctor_kwargs = kwargs
        ctor_kwargs.update(config.get("params", dict()))
        return vae_cls(**ctor_kwargs)

    return vae_cls.from_pretrained(
        config["from_pretrained"],
        params=config.get("params", dict()),
        use_safetensors=config.get('use_safetensors', False),
        variant=config.get('variant', 'fp16'),
    )
|
| 86 |
-
|
| 87 |
-
def instantiate_vae_from_config_local(config, **kwargs):
    """Construct the VAE named by `config["target"]` and load weights from a local checkpoint.

    The checkpoint at `config["from_pretrained"]` is read with `torch.load`.
    If it holds a `state_dict` entry, that entry is used; otherwise the
    checkpoint itself is treated as the state dict with every `vae_model.`
    substring stripped from its keys. Missing and unexpected keys are printed.

    Raises:
        KeyError: If `config` has no `target` key.
        FileNotFoundError: If `from_pretrained` is not set or the file does not exist.
    """
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")

    vae_cls = get_obj_from_str(config["target"])

    if not config.get("from_pretrained", None):
        raise FileNotFoundError("Need from_pretrained!")

    ckpt_path = config["from_pretrained"]

    logger.info(f"Loading model from {ckpt_path}")
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"Model file {ckpt_path} not found")
    ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=True)

    if 'state_dict' in ckpt:
        state_dict = ckpt["state_dict"]
    else:
        # deepspeed ckpt
        state_dict = {name.replace('vae_model.', ''): tensor for name, tensor in ckpt.items()}

    ctor_kwargs = kwargs
    ctor_kwargs.update(config.get("params", dict()))
    instance = vae_cls(**ctor_kwargs)

    missing, unexpected = instance.load_state_dict(state_dict)
    print(f"VAE Missing Keys: {missing}")
    print(f"VAE Unexpected Keys: {unexpected}")

    return instance
|
| 122 |
-
|
| 123 |
-
def disabled_train(self, mode=True):
    """Stand-in for `model.train` that ignores `mode` and hands back `self`,
    so that the train/eval mode does not change anymore."""
    unchanged = self
    return unchanged
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
def instantiate_non_trainable_model(config):
    """Instantiate a model from `config`, put it in eval mode and freeze all parameters."""
    frozen = instantiate_from_config(config).eval()
    # NOTE(review): `disabled_train` is stored as a plain instance attribute,
    # not a bound method, so a later `frozen.train(x)` passes `x` as its `self`.
    frozen.train = disabled_train
    for weight in frozen.parameters():
        weight.requires_grad = False
    return frozen
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
def instantiate_vae_model(config, requires_grad=False):
    """Instantiate a VAE from `config` in eval mode with a pinned `train` attribute.

    Args:
        config: Config passed to `instantiate_vae_from_config`.
        requires_grad: Value assigned to `requires_grad` on every parameter.

    Returns:
        The instantiated model.
    """
    vae = instantiate_vae_from_config(config).eval()
    # NOTE(review): `disabled_train` is stored as a plain instance attribute,
    # not a bound method, so a later `vae.train(x)` passes `x` as its `self`.
    vae.train = disabled_train
    for weight in vae.parameters():
        weight.requires_grad = requires_grad
    return vae
|
| 147 |
-
|
| 148 |
-
def instantiate_vae_model_local(config, requires_grad=False):
    """Load a VAE from a local checkpoint in eval mode with a pinned `train` attribute.

    Args:
        config: Config passed to `instantiate_vae_from_config_local`.
        requires_grad: Value assigned to `requires_grad` on every parameter.

    Returns:
        The instantiated model.
    """
    vae = instantiate_vae_from_config_local(config).eval()
    # NOTE(review): `disabled_train` is stored as a plain instance attribute,
    # not a bound method, so a later `vae.train(x)` passes `x` as its `self`.
    vae.train = disabled_train
    for weight in vae.parameters():
        weight.requires_grad = requires_grad
    return vae
|
| 156 |
-
|
| 157 |
-
def is_dist_avail_and_initialized():
    """Return True only when `torch.distributed` is available and a process group is initialized."""
    return dist.is_available() and dist.is_initialized()
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
def get_rank():
    """Return the distributed rank of this process, or 0 when distributed mode is not active."""
    return dist.get_rank() if is_dist_avail_and_initialized() else 0
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
def get_world_size():
    """Return the number of distributed processes, or 1 when distributed mode is not active."""
    return dist.get_world_size() if is_dist_avail_and_initialized() else 1
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
def all_gather_batch(tensors):
    """
    Performs all_gather operation on the provided tensors.

    With a world size of 1 the input is returned unchanged. Otherwise each
    tensor is gathered from every process and the gathered pieces are
    concatenated along dim 0; the result is a list with one concatenated
    tensor per input tensor.
    """
    num_procs = get_world_size()
    if num_procs == 1:
        # Single process: nothing to gather, hand the input back untouched.
        return tensors

    gathered = []
    for local in tensors:
        # One receive slot per process, shaped like the local tensor.
        slots = [torch.ones_like(local) for _ in range(num_procs)]
        dist.all_gather(slots, local, async_op=False)
        gathered.append(slots)

    return [torch.cat(slots, dim=0) for slots in gathered]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ultrashape/utils/trainings/__init__.py
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
# -*- coding: utf-8 -*-
|
|
|
|
|
|
ultrashape/utils/trainings/callback.py
DELETED
|
@@ -1,213 +0,0 @@
|
|
| 1 |
-
# ------------------------------------------------------------------------------------
|
| 2 |
-
# Modified from Taming Transformers (https://github.com/CompVis/taming-transformers)
|
| 3 |
-
# Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer. All Rights Reserved.
|
| 4 |
-
# ------------------------------------------------------------------------------------
|
| 5 |
-
|
| 6 |
-
import os
|
| 7 |
-
import time
|
| 8 |
-
import wandb
|
| 9 |
-
import numpy as np
|
| 10 |
-
from PIL import Image
|
| 11 |
-
from pathlib import Path
|
| 12 |
-
from omegaconf import OmegaConf, DictConfig
|
| 13 |
-
from typing import Tuple, Generic, Dict, Callable, Optional, Any
|
| 14 |
-
from pprint import pprint
|
| 15 |
-
|
| 16 |
-
import torch
|
| 17 |
-
import torchvision
|
| 18 |
-
import pytorch_lightning as pl
|
| 19 |
-
import pytorch_lightning.loggers
|
| 20 |
-
from pytorch_lightning.loggers import WandbLogger
|
| 21 |
-
from pytorch_lightning.loggers.logger import DummyLogger
|
| 22 |
-
from pytorch_lightning.utilities import rank_zero_only, rank_zero_info
|
| 23 |
-
from pytorch_lightning.callbacks import Callback
|
| 24 |
-
|
| 25 |
-
from functools import wraps
|
| 26 |
-
|
| 27 |
-
def node_zero_only(fn: Callable) -> Callable:
    """Wrap ``fn`` so that it only runs when ``node_zero_only.node`` is 0.

    For any other node value the wrapper skips the call and returns ``None``.
    """
    @wraps(fn)
    def wrapped_fn(*args, **kwargs) -> Optional[Any]:
        if node_zero_only.node != 0:
            return None
        return fn(*args, **kwargs)

    return wrapped_fn


# Node index: keep a value already attached to the decorator, otherwise read
# the NODE_RANK environment variable (0 when it is unset).
node_zero_only.node = getattr(node_zero_only, 'node', int(os.environ.get('NODE_RANK', 0)))
|
| 36 |
-
|
| 37 |
-
def node_zero_experiment(fn: Callable) -> Callable:
    """Returns the real experiment on node 0 and otherwise a dummy experiment.

    ``fn`` is the experiment getter being wrapped. It is only executed when
    ``node_zero_only`` allows it; when that yields a falsy result the
    experiment of a fresh ``DummyLogger`` instance is returned instead.
    """
    @wraps(fn)
    def experiment(self):
        @node_zero_only
        def get_experiment():
            return fn(self)

        # Reading ``experiment`` on the DummyLogger *class* would return the
        # property object itself, so instantiate the logger to obtain a
        # usable placeholder experiment.
        return get_experiment() or DummyLogger().experiment

    return experiment
|
| 46 |
-
|
| 47 |
-
# customize wandb for node 0 only
|
| 48 |
-
class MyWandbLogger(WandbLogger):
    """WandbLogger whose ``experiment`` getter is wrapped by ``node_zero_experiment``."""

    # The getter of the parent's ``experiment`` property is replaced; the
    # replacement goes through ``node_zero_experiment`` before delegating.
    @WandbLogger.experiment.getter
    @node_zero_experiment
    def experiment(self):
        # Delegate to the parent class's ``experiment`` property.
        return super().experiment
|
| 53 |
-
|
| 54 |
-
class SetupCallback(Callback):
    """Create the log and checkpoint directories when fitting starts.

    ``config`` and ``exp_config`` are stored on the instance; ``logdir`` and
    ``ckptdir`` are resolved relative to ``basedir``.
    """

    def __init__(self, config: DictConfig, exp_config: DictConfig,
                 basedir: Path, logdir: str = "log", ckptdir: str = "ckpt") -> None:
        super().__init__()
        self.logdir = basedir / logdir
        self.ckptdir = basedir / ckptdir
        self.config = config
        self.exp_config = exp_config

    def on_fit_start(self, trainer: pl.trainer.Trainer, pl_module: pl.LightningModule) -> None:
        """On global rank 0, make sure both output directories exist."""
        if trainer.global_rank != 0:
            return
        for directory in (self.logdir, self.ckptdir):
            os.makedirs(directory, exist_ok=True)
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
class ImageLogger(Callback):
    """Periodically save the image grids returned by ``pl_module.log_images``.

    Grids are always written as PNG files under ``<save_dir>/results/<split>``
    and are additionally forwarded to a logger-specific handler when one is
    registered for the module's logger.
    """

    def __init__(self, batch_frequency: int, max_images: int, clamp: bool = True,
                 increase_log_steps: bool = True) -> None:
        """
        Args:
            batch_frequency: log whenever ``batch_idx`` is a multiple of this value.
            max_images: maximum number of images kept per key; nothing is
                logged unless this is positive.
            clamp: clamp image tensors to [0, 1] before saving.
            increase_log_steps: additionally log at the power-of-two steps
                up to ``batch_frequency``.
        """
        super().__init__()
        self.batch_freq = batch_frequency
        self.max_images = max_images
        self.logger_log_images = {
            pl.loggers.WandbLogger: self._wandb,
        }
        # TestTubeLogger is not shipped by every pytorch_lightning release;
        # register its handler only when the class exists so that building
        # this callback does not fail with AttributeError.
        testtube_cls = getattr(pl.loggers, "TestTubeLogger", None)
        if testtube_cls is not None:
            self.logger_log_images[testtube_cls] = self._testtube
        self.log_steps = [2 ** n for n in range(int(np.log2(self.batch_freq)) + 1)]
        if not increase_log_steps:
            self.log_steps = [self.batch_freq]
        self.clamp = clamp

    @rank_zero_only
    def _wandb(self, pl_module, images, batch_idx, split):
        """Send one grid per key to the module logger's experiment as ``wandb.Image``."""
        grids = {
            f"{split}/{k}": wandb.Image(torchvision.utils.make_grid(images[k]))
            for k in images
        }
        pl_module.logger.experiment.log(grids)

    @rank_zero_only
    def _testtube(self, pl_module, images, batch_idx, split):
        """Add one grid per key to the module logger's experiment via ``add_image``."""
        for k in images:
            grid = torchvision.utils.make_grid(images[k])
            grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w

            tag = f"{split}/{k}"
            pl_module.logger.experiment.add_image(
                tag, grid,
                global_step=pl_module.global_step)

    @rank_zero_only
    def log_local(self, save_dir: str, split: str, images: Dict,
                  global_step: int, current_epoch: int, batch_idx: int) -> None:
        """Write one PNG grid per key under ``<save_dir>/results/<split>``."""
        root = os.path.join(save_dir, "results", split)
        os.makedirs(root, exist_ok=True)
        for k in images:
            grid = torchvision.utils.make_grid(images[k], nrow=4)

            # Move the first axis last (c,h,w -> h,w,c) for PIL.
            grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
            grid = grid.numpy()
            grid = (grid * 255).astype(np.uint8)
            filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(
                k,
                global_step,
                current_epoch,
                batch_idx)
            path = os.path.join(root, filename)
            # The key is part of the file name and may contain separators.
            os.makedirs(os.path.split(path)[0], exist_ok=True)
            Image.fromarray(grid).save(path)

    def _select_logger_handler(self, logger):
        """Return the handler registered for ``logger``, or ``None``.

        The exact type is tried first; otherwise the first registered class
        that ``logger`` is an instance of wins, so subclasses of a registered
        logger class are handled as well.
        """
        handler = self.logger_log_images.get(type(logger))
        if handler is not None:
            return handler
        for logger_cls, candidate in self.logger_log_images.items():
            if isinstance(logger, logger_cls):
                return candidate
        return None

    def log_img(self, pl_module: pl.LightningModule, batch: Tuple[torch.LongTensor, torch.FloatTensor], batch_idx: int,
                split: str = "train") -> None:
        """Collect images from ``pl_module.log_images`` and log them if due.

        Nothing happens unless ``check_frequency(batch_idx)`` is true, the
        module exposes a callable ``log_images`` and ``max_images`` is
        positive. The module is switched to eval mode for the call and put
        back into train mode afterwards if it was training.
        """
        if not (self.check_frequency(batch_idx) and  # batch_idx % self.batch_freq == 0
                hasattr(pl_module, "log_images") and
                callable(pl_module.log_images) and
                self.max_images > 0):
            return

        is_train = pl_module.training
        if is_train:
            pl_module.eval()

        with torch.no_grad():
            images = pl_module.log_images(batch, split=split, pl_module=pl_module)

        for k in images:
            N = min(images[k].shape[0], self.max_images)
            images[k] = images[k][:N].detach().cpu()
            if self.clamp:
                images[k] = images[k].clamp(0, 1)

        self.log_local(pl_module.logger.save_dir, split, images,
                       pl_module.global_step, pl_module.current_epoch, batch_idx)

        handler = self._select_logger_handler(pl_module.logger)
        if handler is not None:
            handler(pl_module, images, pl_module.global_step, split)

        if is_train:
            pl_module.train()

    def check_frequency(self, batch_idx: int) -> bool:
        """Return True when ``batch_idx`` is due; consumes one pending log step."""
        if (batch_idx % self.batch_freq) == 0 or (batch_idx in self.log_steps):
            if self.log_steps:
                self.log_steps.pop(0)
            return True
        return False

    def on_train_batch_end(self, trainer: pl.trainer.Trainer, pl_module: pl.LightningModule,
                           outputs: Generic, batch: Tuple[torch.LongTensor, torch.FloatTensor], batch_idx: int) -> None:
        self.log_img(pl_module, batch, batch_idx, split="train")

    # NOTE(review): the parameters are declared as (dataloader_idx, batch_idx);
    # confirm this matches the order in which the installed pytorch_lightning
    # passes these arguments to the hook.
    def on_validation_batch_end(self, trainer: pl.trainer.Trainer, pl_module: pl.LightningModule,
                                outputs: Generic, batch: Tuple[torch.LongTensor, torch.FloatTensor],
                                dataloader_idx: int, batch_idx: int) -> None:
        self.log_img(pl_module, batch, batch_idx, split="val")
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
class CUDACallback(Callback):
    """Report per-epoch wall time and peak CUDA memory of ``trainer.root_gpu``."""

    # see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py
    def on_train_epoch_start(self, trainer, pl_module):
        # Reset the memory use counter
        torch.cuda.reset_peak_memory_stats(trainer.root_gpu)
        torch.cuda.synchronize(trainer.root_gpu)
        self.start_time = time.time()

    def on_train_epoch_end(self, trainer, pl_module, outputs=None):
        # ``outputs`` is optional so the hook also works when it is invoked
        # with only ``(trainer, pl_module)``.
        torch.cuda.synchronize(trainer.root_gpu)
        # Bytes -> MiB.
        max_memory = torch.cuda.max_memory_allocated(trainer.root_gpu) / 2 ** 20
        epoch_time = time.time() - self.start_time

        try:
            max_memory = trainer.training_type_plugin.reduce(max_memory)
            epoch_time = trainer.training_type_plugin.reduce(epoch_time)

            rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")
            rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB")
        except AttributeError:
            # Best effort: skip the report when the trainer does not expose
            # the expected attribute.
            pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ultrashape/utils/trainings/lr_scheduler.py
DELETED
|
@@ -1,53 +0,0 @@
|
|
| 1 |
-
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
| 2 |
-
# except for the third-party components listed below.
|
| 3 |
-
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
| 4 |
-
# in the respective licenses of these third-party components.
|
| 5 |
-
# Users must comply with all terms and conditions of original licenses of these third-party
|
| 6 |
-
# components and must ensure that the usage of the third party components adheres to
|
| 7 |
-
# all relevant laws and regulations.
|
| 8 |
-
|
| 9 |
-
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
| 10 |
-
# their software and algorithms, including trained model weights, parameters (including
|
| 11 |
-
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
| 12 |
-
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
| 13 |
-
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
| 14 |
-
|
| 15 |
-
import numpy as np
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
class BaseScheduler(object):
    """Base class for step-indexed schedules; subclasses implement ``schedule``."""

    def schedule(self, n, **kwargs):
        raise NotImplementedError


class LambdaWarmUpCosineFactorScheduler(BaseScheduler):
    """
    Linear warm-up followed by cosine decay of a multiplicative factor.

    note: use with a base_lr of 1.0
    """
    def __init__(self, warm_up_steps, f_min, f_max, f_start, max_decay_steps, verbosity_interval=0, **ignore_kwargs):
        """
        Args:
            warm_up_steps: number of steps over which the factor moves
                linearly from ``f_start`` to ``f_max``.
            f_min: factor reached at the end of the cosine decay.
            f_max: factor reached at the end of the warm-up.
            f_start: factor at step 0.
            max_decay_steps: step at which the decay reaches ``f_min``.
            verbosity_interval: print the most recent factor every this many
                steps; 0 disables printing.
            **ignore_kwargs: accepted and ignored.
        """
        self.lr_warm_up_steps = warm_up_steps
        self.f_start = f_start
        self.f_min = f_min
        self.f_max = f_max
        self.lr_max_decay_steps = max_decay_steps
        self.last_f = 0.
        self.verbosity_interval = verbosity_interval

    def schedule(self, n, **kwargs):
        """Return the factor for step ``n`` and remember it in ``last_f``."""
        if self.verbosity_interval > 0 and n % self.verbosity_interval == 0:
            # Report the factor produced by the previous call, not f_start.
            print(f"current step: {n}, recent lr-multiplier: {self.last_f}")

        if n < self.lr_warm_up_steps:
            f = (self.f_max - self.f_start) / self.lr_warm_up_steps * n + self.f_start
        else:
            decay_span = self.lr_max_decay_steps - self.lr_warm_up_steps
            if decay_span > 0:
                t = min((n - self.lr_warm_up_steps) / decay_span, 1.0)
            else:
                # No room to decay: treat the decay as already finished
                # instead of dividing by zero.
                t = 1.0
            f = self.f_min + 0.5 * (self.f_max - self.f_min) * (1 + np.cos(t * np.pi))

        self.last_f = f
        return f

    def __call__(self, n, **kwargs):
        return self.schedule(n, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ultrashape/utils/trainings/mesh.py
DELETED
|
@@ -1,128 +0,0 @@
|
|
| 1 |
-
# -*- coding: utf-8 -*-
|
| 2 |
-
|
| 3 |
-
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
| 4 |
-
# except for the third-party components listed below.
|
| 5 |
-
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
| 6 |
-
# in the respective licenses of these third-party components.
|
| 7 |
-
# Users must comply with all terms and conditions of original licenses of these third-party
|
| 8 |
-
# components and must ensure that the usage of the third party components adheres to
|
| 9 |
-
# all relevant laws and regulations.
|
| 10 |
-
|
| 11 |
-
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
| 12 |
-
# their software and algorithms, including trained model weights, parameters (including
|
| 13 |
-
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
| 14 |
-
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
| 15 |
-
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
| 16 |
-
|
| 17 |
-
import os
|
| 18 |
-
import cv2
|
| 19 |
-
import numpy as np
|
| 20 |
-
import PIL.Image
|
| 21 |
-
from typing import Optional
|
| 22 |
-
|
| 23 |
-
import trimesh
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
def save_obj(pointnp_px3, facenp_fx3, fname):
    """Write vertices and triangular faces to ``fname`` as OBJ text.

    Args:
        pointnp_px3: iterable of vertices; the first three entries of each
            are written as a ``v`` line.
        facenp_fx3: iterable of faces holding 0-based vertex indices; the
            first three entries of each are written 1-based as an ``f`` line.
        fname: path of the file to write.
    """
    lines = ["v %f %f %f\n" % (p[0], p[1], p[2]) for p in pointnp_px3]
    # Face indices are shifted by one when written.
    lines.extend(
        "f %d %d %d\n" % (f[0] + 1, f[1] + 1, f[2] + 1) for f in facenp_fx3
    )
    # Context manager so the handle is closed even if the write fails.
    with open(fname, "w") as fid:
        fid.write("".join(lines))
    return
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
def savemeshtes2(pointnp_px3, tcoords_px2, facenp_fx3, facetex_fx3, tex_map, fname):
    """Write a textured mesh as OBJ plus a matching MTL file and PNG texture.

    The ``.mtl`` and ``.png`` files are placed next to ``fname`` and share
    its base name.

    Args:
        pointnp_px3: iterable of vertices, written as ``v`` lines.
        tcoords_px2: iterable of texture coordinates, written as ``vt`` lines.
        facenp_fx3: faces as 0-based vertex indices.
        facetex_fx3: per-face 0-based texture-coordinate indices, indexed in
            step with ``facenp_fx3``.
        tex_map: texture image array, saved through PIL in "RGB" mode.
        fname: path of the OBJ file to write.
    """
    fol, na = os.path.split(fname)
    na, _ = os.path.splitext(na)

    # os.path.join keeps the material file beside ``fname`` even when
    # ``fname`` has no directory part (a "%s/%s" format would then point at
    # the filesystem root).
    matname = os.path.join(fol, "%s.mtl" % na)
    with open(matname, "w") as fid:
        fid.write("newmtl material_0\n")
        fid.write("Kd 1 1 1\n")
        fid.write("Ka 0 0 0\n")
        fid.write("Ks 0.4 0.4 0.4\n")
        fid.write("Ns 10\n")
        fid.write("illum 2\n")
        fid.write("map_Kd %s.png\n" % na)

    with open(fname, "w") as fid:
        fid.write("mtllib %s.mtl\n" % na)

        for p3 in pointnp_px3:
            fid.write("v %f %f %f\n" % (p3[0], p3[1], p3[2]))

        for p2 in tcoords_px2:
            fid.write("vt %f %f\n" % (p2[0], p2[1]))

        fid.write("usemtl material_0\n")
        for i, f in enumerate(facenp_fx3):
            # Both index kinds are shifted by one when written.
            f1 = f + 1
            f2 = facetex_fx3[i] + 1
            fid.write("f %d/%d %d/%d %d/%d\n" % (f1[0], f2[0], f1[1], f2[1], f1[2], f2[2]))

    PIL.Image.fromarray(np.ascontiguousarray(tex_map), "RGB").save(
        os.path.join(fol, "%s.png" % na))

    return
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
class MeshOutput(object):
    """Holds mesh arrays and exports them with a writer chosen by content."""

    def __init__(self,
                 mesh_v: np.ndarray,
                 mesh_f: np.ndarray,
                 vertex_colors: Optional[np.ndarray] = None,
                 uvs: Optional[np.ndarray] = None,
                 mesh_tex_idx: Optional[np.ndarray] = None,
                 tex_map: Optional[np.ndarray] = None):
        self.mesh_v, self.mesh_f = mesh_v, mesh_f
        self.vertex_colors = vertex_colors
        self.uvs, self.mesh_tex_idx, self.tex_map = uvs, mesh_tex_idx, tex_map

    def contain_uv_texture(self):
        """True when uvs, texture indices and the texture map are all set."""
        texture_parts = (self.uvs, self.mesh_tex_idx, self.tex_map)
        return all(part is not None for part in texture_parts)

    def contain_vertex_colors(self):
        """True when per-vertex colors are set."""
        return self.vertex_colors is not None

    def export(self, fname):
        """Write the mesh to ``fname``.

        A UV-textured mesh goes through ``savemeshtes2``; otherwise a mesh
        with vertex colors is exported via ``trimesh``; anything else is
        written with ``save_obj``.
        """
        if self.contain_uv_texture():
            savemeshtes2(self.mesh_v, self.uvs, self.mesh_f,
                         self.mesh_tex_idx, self.tex_map, fname)
            return

        if self.contain_vertex_colors():
            colored = trimesh.Trimesh(vertices=self.mesh_v, faces=self.mesh_f,
                                      vertex_colors=self.vertex_colors)
            colored.export(fname)
            return

        save_obj(self.mesh_v, self.mesh_f, fname)
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ultrashape/utils/trainings/mesh_log_callback.py
DELETED
|
@@ -1,342 +0,0 @@
|
|
| 1 |
-
# -*- coding: utf-8 -*-
|
| 2 |
-
|
| 3 |
-
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
| 4 |
-
# except for the third-party components listed below.
|
| 5 |
-
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
| 6 |
-
# in the respective licenses of these third-party components.
|
| 7 |
-
# Users must comply with all terms and conditions of original licenses of these third-party
|
| 8 |
-
# components and must ensure that the usage of the third party components adheres to
|
| 9 |
-
# all relevant laws and regulations.
|
| 10 |
-
|
| 11 |
-
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
| 12 |
-
# their software and algorithms, including trained model weights, parameters (including
|
| 13 |
-
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
| 14 |
-
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
| 15 |
-
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
| 16 |
-
|
| 17 |
-
import json
|
| 18 |
-
import math
|
| 19 |
-
import os
|
| 20 |
-
from typing import Tuple, Generic, Dict, List, Union, Optional
|
| 21 |
-
|
| 22 |
-
import trimesh
|
| 23 |
-
import numpy as np
|
| 24 |
-
import pytorch_lightning as pl
|
| 25 |
-
import pytorch_lightning.loggers
|
| 26 |
-
import torch
|
| 27 |
-
import torchvision
|
| 28 |
-
from pytorch_lightning.callbacks import Callback
|
| 29 |
-
from pytorch_lightning.utilities import rank_zero_only
|
| 30 |
-
|
| 31 |
-
from hy3dshape.pipelines import export_to_trimesh
|
| 32 |
-
from hy3dshape.utils.trainings.mesh import MeshOutput
|
| 33 |
-
from hy3dshape.utils.visualizers import html_util
|
| 34 |
-
from hy3dshape.utils.visualizers.pythreejs_viewer import PyThreeJSViewer
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
class ImageConditionalASLDiffuserLogger(Callback):
    """Callback that samples meshes from ``pl_module`` and writes HTML previews.

    Every ``step_frequency`` global steps a few batch entries are picked at
    random, ``pl_module.sample`` is run on them and one HTML page per entry
    (input image next to the generated mesh) is written below
    ``<logger.save_dir>/visuals/<split>/``.
    """

    def __init__(self,
                 step_frequency: int,
                 num_samples: int = 1,
                 mean: Optional[Union[List[float], Tuple[float]]] = None,
                 std: Optional[Union[List[float], Tuple[float]]] = None,
                 bounds: Union[List[float], Tuple[float]] = (-1.1, -1.1, -1.1, 1.1, 1.1, 1.1),
                 **kwargs) -> None:
        """
        Args:
            step_frequency: sample whenever the global step is a multiple of this.
            num_samples: number of batch entries drawn per logging event.
            mean: optional value(s) added back in ``denormalize_image``.
            std: optional value(s) multiplied back in ``denormalize_image``.
            bounds: six numbers; the last three minus the first three give
                ``bbox_size``.
            **kwargs: accepted and ignored.
        """

        super().__init__()
        self.bbox_size = np.array(bounds[3:6]) - np.array(bounds[0:3])

        if mean is not None:
            mean = np.asarray(mean)

        if std is not None:
            std = np.asarray(std)

        self.mean = mean
        self.std = std

        self.step_freq = step_frequency
        self.num_samples = num_samples
        # Set after a train-side log; lets the next validation batch log once.
        self.has_train_logged = False
        # NOTE(review): this mapping is built here but not read anywhere in
        # this class.
        self.logger_log_images = {
            pl.loggers.WandbLogger: self._wandb,
        }

        self.viewer = PyThreeJSViewer(settings={}, render_mode="WEBSITE")

    @rank_zero_only
    def _wandb(self, pl_module, images, batch_idx, split):
        """Log one image grid per key to the module logger's experiment."""
        # NOTE(review): `wandb` is not among the imports visible for this
        # file — confirm it is in scope before relying on this method.
        # raise ValueError("No way wandb")
        grids = dict()
        for k in images:
            grid = torchvision.utils.make_grid(images[k])
            grids[f"{split}/{k}"] = wandb.Image(grid)
        pl_module.logger.experiment.log(grids)

    def log_local(self,
                  outputs: List[List['Latent2MeshOutput']],
                  images: Union[np.ndarray, List[np.ndarray]],
                  description: List[str],
                  keys: List[str],
                  save_dir: str, split: str,
                  global_step: int, current_epoch: int, batch_idx: int,
                  prog_bar: bool = False,
                  multi_views=None,  # yf ...
                  ) -> None:
        """Write one HTML page per sample with the input image and the mesh.

        Args:
            outputs: nested list indexed as ``outputs[j][i]`` for sample ``i``;
                only ``j == 0`` is read, and ``None`` entries are skipped.
            images: per-sample images, passed through ``denormalize_image``.
            description: per-sample caption text.
            keys: per-sample names; used (cut to 100 characters) as file names.
            save_dir: root directory for the output.
            split: sub-directory name below ``visuals``.
            global_step: used in the output folder name.
            current_epoch: used in the output folder name.
            batch_idx: used in the output folder name.
            prog_bar: not used in this method.
            multi_views: optional per-sample views; when given, an extra
                table cell with their grid is added.
        """

        folder = "gs-{:010}_e-{:06}_b-{:06}".format(global_step, current_epoch, batch_idx)
        visual_dir = os.path.join(save_dir, "visuals", split, folder)
        os.makedirs(visual_dir, exist_ok=True)

        num_samples = len(images)

        for i in range(num_samples):
            key_i = keys[i]
            image_i = self.denormalize_image(images[i])
            shape_tag_i = description[i]

            for j in range(1):
                mesh = outputs[j][i]
                if mesh is None:
                    continue

                mesh_v = mesh.mesh_v.copy()
                # Shift mesh j along x by the largest bbox extent (zero for j == 0).
                mesh_v[:, 0] += j * np.max(self.bbox_size)
                self.viewer.add_mesh(mesh_v, mesh.mesh_f)

            image_tag = html_util.to_image_embed_tag(image_i)
            mesh_tag = self.viewer.to_html(html_frame=False)

            table_tag = f"""
<table border = "1">
<caption> {shape_tag_i} - {key_i} </caption>
<caption> Input Image | Generated Mesh </caption>
<tr>
<td>{image_tag}</td>
<td>{mesh_tag}</td>
</tr>
</table>
"""

            if multi_views is not None:
                multi_views_i = self.make_grid(multi_views[i])
                views_tag = html_util.to_image_embed_tag(self.denormalize_image(multi_views_i))
                table_tag = f"""
<table border = "1">
<caption> {shape_tag_i} - {key_i} </caption>
<caption> Input Image | Generated Mesh </caption>
<tr>
<td>{image_tag}</td>
<td>{views_tag}</td>
<td>{mesh_tag}</td>
</tr>
</table>
"""

            html_frame = html_util.to_html_frame(table_tag)
            # Keep the file name at 100 characters at most.
            if len(key_i) > 100:
                key_i = key_i[:100]
            with open(os.path.join(visual_dir, f"{key_i}.html"), "w") as writer:
                writer.write(html_frame)

            # Clear the viewer before the next sample.
            self.viewer.reset()

    def log_sample(self,
                   pl_module: pl.LightningModule,
                   batch: Dict[str, torch.FloatTensor],
                   batch_idx: int,
                   split: str = "train") -> None:
        """
        Sample meshes for randomly chosen batch entries and log them locally.

        Args:
            pl_module: module providing ``sample``, ``logger``, ``global_step``
                and ``current_epoch``.
            batch (dict): the batch sample information, and it contains:
                - surface (torch.FloatTensor): only its length is read here.
                - image (torch.FloatTensor): conditioning images.
                - uid: per-entry identifiers, used in the description text.
                - mask (optional): forwarded to ``sample`` when present.
            batch_idx (int): used in the output folder name.
            split (str): sub-directory name for the output.

        Returns:
            None
        """

        is_train = pl_module.training
        if is_train:
            pl_module.eval()

        batch_size = len(batch["surface"])
        # Draw with replacement only when the batch is smaller than num_samples.
        replace = batch_size < self.num_samples
        ids = np.random.choice(batch_size, self.num_samples, replace=replace)

        with torch.no_grad():
            # run text to mesh
            # keys = [batch["__key__"][i] for i in ids]
            keys = [f'key_{i}' for i in ids]
            # texts = [batch["text"][i] for i in ids]
            texts = [f'text_{i}'for i in ids]
            # description = [batch["description"][i] for i in ids]
            description = [f'desc_{i}_{os.path.splitext(os.path.basename(batch["uid"][i]))[0]}' for i in ids]
            images = batch["image"][ids]
            mask_input = batch["mask"][ids] if 'mask' in batch else None
            # uids = batch["uid"][ids]
            sample_batch = {
                "__key__": keys,
                "image": images,
                'text': texts,
                'mask': mask_input,
            }

            # if 'cam_parm' in batch:
            #     sample_batch['cam_parm'] = batch['cam_parm'][ids]

            # if 'multi_views' in batch:  # yf ...
            #     sample_batch['multi_views'] = batch['multi_views'][ids]

            outputs = pl_module.sample(
                batch=sample_batch,
                output_type='latents2mesh'
            )

            images = images.cpu().float().numpy()
            # images = self.denormalize_image(images)
            # images = np.transpose(images, (0, 2, 3, 1))
            # images = ((images + 1) / 2 * 255).astype(np.uint8)

            # 'multi_views' is never put into sample_batch above, so .get()
            # yields None here.
            self.log_local(outputs, images, description, keys, pl_module.logger.save_dir, split,
                           pl_module.global_step, pl_module.current_epoch, batch_idx, prog_bar=False,
                           multi_views=sample_batch.get('multi_views'))

        if is_train: pl_module.train()

    def make_grid(self, images):  # return (3,h,w) in (0,1) ...
        """Resize each image to 320x320 and tile them two per row; returns a numpy array."""
        images_resized = []
        for img in images:
            img_resized = torchvision.transforms.functional.resize(img, (320, 320))
            images_resized.append(img_resized)
        image = torchvision.utils.make_grid(images_resized, nrow=2, padding=5, pad_value=255)

        image = image.cpu().numpy()
        # image = np.transpose(image, (1, 2, 0))
        # image = (image * 255).astype(np.uint8)

        return image

    def check_frequency(self, step: int) -> bool:
        """Return True when ``step`` is a multiple of ``step_freq``."""
        if step % self.step_freq == 0:
            return True
        return False

    def on_train_batch_end(self, trainer: pl.trainer.Trainer, pl_module: pl.LightningModule,
                           outputs: Generic, batch: Dict[str, torch.FloatTensor], batch_idx: int) -> None:
        """Log samples when the global step is due and the module can sample."""

        if (self.check_frequency(pl_module.global_step) and  # batch_idx % self.batch_freq == 0
                hasattr(pl_module, "sample") and
                callable(pl_module.sample) and
                self.num_samples > 0):
            self.log_sample(pl_module, batch, batch_idx, split="train")
            self.has_train_logged = True

    # NOTE(review): the parameters are declared as (dataloader_idx, batch_idx);
    # confirm this matches the order in which the installed pytorch_lightning
    # passes these arguments to the hook.
    def on_validation_batch_end(self, trainer: pl.trainer.Trainer, pl_module: pl.LightningModule,
                                outputs: Generic, batch: Dict[str, torch.FloatTensor],
                                dataloader_idx: int, batch_idx: int) -> None:
        """Log validation samples once after each train-side logging event."""

        if self.has_train_logged:
            self.log_sample(pl_module, batch, batch_idx, split="val")
            self.has_train_logged = False

    def denormalize_image(self, image):
        """
        Move the first axis last, undo the optional std/mean normalization
        and scale by 255.

        Args:
            image (np.ndarray): [3, h, w]

        Returns:
            image (np.ndarray): [h, w, 3], np.uint8, [0, 255].
        """
        # image = np.transpose(image, (0, 2, 3, 1))
        image = np.transpose(image, (1, 2, 0))

        if self.std is not None:
            image = image * self.std

        if self.mean is not None:
            image = image + self.mean

        image = (image * 255).astype(np.uint8)

        return image
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
class ImageConditionalFixASLDiffuserLogger(Callback):
    """Callback that periodically samples meshes for a fixed list of images.

    The image list is read from the JSON file at ``test_data_path`` (key
    ``'file_list'``).  Every ``step_frequency`` global steps the list is split
    across ranks, ``pl_module.sample`` is called once per image, and each
    resulting mesh is exported as ``.glb`` under
    ``<logger.save_dir>/<save_dir>/gs-..._e-..._b-.../``.
    """

    def __init__(
        self,
        step_frequency: int,
        test_data_path: str,
        max_size: int = None,
        save_dir: str = 'infer',
        **kwargs,
    ) -> None:
        """
        Args:
            step_frequency: Sample whenever ``global_step`` is a multiple of this.
            test_data_path: JSON file containing a ``'file_list'`` entry.
            max_size: Optional cap on the number of list entries that are used.
            save_dir: Sub-directory (below the logger's save dir) for the exports.
            **kwargs: Forwarded unchanged to ``pl_module.sample``.
        """
        super().__init__()
        self.step_freq = step_frequency
        self.viewer = PyThreeJSViewer(settings={}, render_mode="WEBSITE")

        self.test_data_path = test_data_path
        # Remembered so that the cap is re-applied on every reload of the list.
        self.max_size = max_size
        self.file_list = self._load_file_list()
        self.kwargs = kwargs
        self.save_dir = save_dir

    def _load_file_list(self):
        """Read ``'file_list'`` from the JSON file, truncated to ``max_size``."""
        with open(self.test_data_path, 'r') as f:
            data = json.load(f)
        file_list = data['file_list']
        if self.max_size is not None:
            file_list = file_list[:self.max_size]
        return file_list

    def on_train_batch_end(
        self,
        trainer: pl.trainer.Trainer,
        pl_module: pl.LightningModule,
        outputs: Generic,
        batch: Dict[str, torch.FloatTensor],
        batch_idx: int,
    ):
        """Sample and export a mesh for this rank's share of the image list."""
        if pl_module.global_step % self.step_freq != 0:
            return

        # The list is re-read on every run, so edits to the JSON file are
        # picked up while training is in progress.
        self.file_list = self._load_file_list()

        is_train = pl_module.training
        if is_train:
            pl_module.eval()

        try:
            folder = "gs-{:010}_e-{:06}_b-{:06}".format(pl_module.global_step, pl_module.current_epoch, batch_idx)
            visual_dir = os.path.join(pl_module.logger.save_dir, self.save_dir, folder)
            os.makedirs(visual_dir, exist_ok=True)

            # Contiguous chunk per rank; the last rank takes whatever is left.
            image_paths = self.file_list
            chunk_size = math.ceil(len(image_paths) / trainer.world_size)
            start = pl_module.global_rank * chunk_size
            if pl_module.global_rank == trainer.world_size - 1:
                image_paths = image_paths[start:]
            else:
                image_paths = image_paths[start:start + chunk_size]

            print(f'Rank{pl_module.global_rank}: processing {len(image_paths)}|{len(self.file_list)} images')
            for image_path in image_paths:
                save_path = os.path.join(visual_dir, os.path.basename(image_path))
                save_path = os.path.splitext(save_path)[0] + '.glb'

                if isinstance(image_path, str):
                    print(image_path)

                with torch.no_grad():
                    mesh = pl_module.sample(batch={"image": image_path}, **self.kwargs)[0][0]

                # A 2-tuple is converted first.  The export check is a separate
                # ``if`` (not ``elif``) so that converted meshes are written too;
                # previously they were converted and then silently dropped.
                if isinstance(mesh, tuple) and len(mesh) == 2:
                    mesh = export_to_trimesh(mesh)
                if isinstance(mesh, trimesh.Trimesh):
                    os.makedirs(os.path.dirname(save_path), exist_ok=True)
                    mesh.export(save_path)
        finally:
            # Restore training mode even when sampling or export raised.
            if is_train:
                pl_module.train()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ultrashape/utils/trainings/peft.py
DELETED
|
@@ -1,78 +0,0 @@
|
|
| 1 |
-
# -*- coding: utf-8 -*-
|
| 2 |
-
|
| 3 |
-
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
| 4 |
-
# except for the third-party components listed below.
|
| 5 |
-
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
| 6 |
-
# in the repsective licenses of these third-party components.
|
| 7 |
-
# Users must comply with all terms and conditions of original licenses of these third-party
|
| 8 |
-
# components and must ensure that the usage of the third party components adheres to
|
| 9 |
-
# all relevant laws and regulations.
|
| 10 |
-
|
| 11 |
-
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
| 12 |
-
# their software and algorithms, including trained model weights, parameters (including
|
| 13 |
-
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
| 14 |
-
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
| 15 |
-
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
| 16 |
-
|
| 17 |
-
import os
|
| 18 |
-
from pytorch_lightning.callbacks import Callback
|
| 19 |
-
from omegaconf import OmegaConf, ListConfig
|
| 20 |
-
|
| 21 |
-
class PeftSaveCallback(Callback):
    """Callback that saves ``peft_model`` via ``save_pretrained``.

    A checkpoint is written at the end of every training epoch
    (``<save_dir>/epoch_<n>``) and, when ``save_every_n_steps`` is given,
    every that many global steps (``<save_dir>/step_<n>``).  Before each
    save, OmegaConf containers found inside ``peft_model.peft_config`` are
    converted to plain Python containers.
    """

    def __init__(self, peft_model, save_dir: str, save_every_n_steps: int = None):
        """
        Args:
            peft_model: Object exposing ``peft_config`` and ``save_pretrained``.
            save_dir: Directory (created if missing) that receives the checkpoints.
            save_every_n_steps: Optional step interval for extra checkpoints;
                ``None`` disables step-based saving.
        """
        super().__init__()
        self.peft_model = peft_model
        self.save_dir = save_dir
        self.save_every_n_steps = save_every_n_steps
        os.makedirs(self.save_dir, exist_ok=True)

    def recursive_convert(self, obj):
        """Replace OmegaConf containers inside ``obj`` with plain containers.

        Dicts and lists are rebuilt; any other object that has a ``__dict__``
        is updated in place, attribute by attribute, and returned.  Classes
        and everything else are returned untouched.
        """
        # ``OmegaConf`` is a utility class, not the base class of config
        # nodes, so ``isinstance(obj, OmegaConf)`` never matched a DictConfig.
        # ``OmegaConf.is_config`` covers both DictConfig and ListConfig.
        if OmegaConf.is_config(obj):
            return OmegaConf.to_container(obj, resolve=True)
        if isinstance(obj, dict):
            return {k: self.recursive_convert(v) for k, v in obj.items()}
        if isinstance(obj, list):
            return [self.recursive_convert(i) for i in obj]
        if isinstance(obj, type):
            # Avoid mutating class objects.
            return obj
        if hasattr(obj, '__dict__'):
            # Snapshot the items so that setattr cannot disturb the iteration.
            for attr_name, attr_value in list(vars(obj).items()):
                setattr(obj, attr_name, self.recursive_convert(attr_value))
        return obj

    def _convert_peft_config(self):
        """Convert ``peft_model.peft_config`` to plain containers."""
        pc = self.peft_model.peft_config
        self.peft_model.peft_config = self.recursive_convert(pc)

    def _save(self, name):
        """Convert the config, save to ``<save_dir>/<name>`` and report it."""
        self._convert_peft_config()
        save_path = os.path.join(self.save_dir, name)
        self.peft_model.save_pretrained(save_path)
        print(f"[PeftSaveCallback] Saved LoRA weights to {save_path}")

    def on_train_epoch_end(self, trainer, pl_module):
        """Save a checkpoint named after the epoch that just finished."""
        self._save(f"epoch_{trainer.current_epoch}")

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        """Save a checkpoint every ``save_every_n_steps`` global steps."""
        if self.save_every_n_steps is None:
            return
        global_step = trainer.global_step
        # Step 0 is skipped: nothing has been trained yet.
        if global_step % self.save_every_n_steps == 0 and global_step > 0:
            self._save(f"step_{global_step}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ultrashape/utils/typing.py
DELETED
|
@@ -1,41 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
This module contains type annotations for the project, using
|
| 3 |
-
1. Python type hints (https://docs.python.org/3/library/typing.html) for Python objects
|
| 4 |
-
2. jaxtyping (https://github.com/google/jaxtyping/blob/main/API.md) for PyTorch tensors
|
| 5 |
-
|
| 6 |
-
Two kinds of type checking can be used:
|
| 7 |
-
1. Static type checking with mypy (install with pip and enabled as the default linter in VSCode)
|
| 8 |
-
2. Runtime type checking with typeguard (install with pip and triggered at runtime, mainly for tensor dtype and shape checking)
|
| 9 |
-
"""
|
| 10 |
-
|
| 11 |
-
# Basic types
|
| 12 |
-
from typing import (
|
| 13 |
-
Any,
|
| 14 |
-
Callable,
|
| 15 |
-
Dict,
|
| 16 |
-
Iterable,
|
| 17 |
-
List,
|
| 18 |
-
Literal,
|
| 19 |
-
NamedTuple,
|
| 20 |
-
NewType,
|
| 21 |
-
Optional,
|
| 22 |
-
Sized,
|
| 23 |
-
Tuple,
|
| 24 |
-
Type,
|
| 25 |
-
TypeVar,
|
| 26 |
-
Union,
|
| 27 |
-
Sequence,
|
| 28 |
-
)
|
| 29 |
-
|
| 30 |
-
# Tensor dtype
|
| 31 |
-
# for jaxtyping usage, see https://github.com/google/jaxtyping/blob/main/API.md
|
| 32 |
-
from jaxtyping import Bool, Complex, Float, Inexact, Int, Integer, Num, Shaped, UInt
|
| 33 |
-
|
| 34 |
-
# Config type
|
| 35 |
-
from omegaconf import DictConfig
|
| 36 |
-
|
| 37 |
-
# PyTorch Tensor type
|
| 38 |
-
from torch import Tensor
|
| 39 |
-
|
| 40 |
-
# Runtime type checking decorator
|
| 41 |
-
from typeguard import typechecked as typechecker
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ultrashape/utils/utils.py
DELETED
|
@@ -1,128 +0,0 @@
|
|
| 1 |
-
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
| 2 |
-
# except for the third-party components listed below.
|
| 3 |
-
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
| 4 |
-
# in the repsective licenses of these third-party components.
|
| 5 |
-
# Users must comply with all terms and conditions of original licenses of these third-party
|
| 6 |
-
# components and must ensure that the usage of the third party components adheres to
|
| 7 |
-
# all relevant laws and regulations.
|
| 8 |
-
|
| 9 |
-
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
| 10 |
-
# their software and algorithms, including trained model weights, parameters (including
|
| 11 |
-
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
| 12 |
-
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
| 13 |
-
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
| 14 |
-
|
| 15 |
-
import logging
|
| 16 |
-
import os
|
| 17 |
-
from functools import wraps
|
| 18 |
-
|
| 19 |
-
import torch
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
def get_logger(name):
    """Return the INFO-level logger ``name`` with one console handler.

    ``logging.getLogger`` returns the same object for the same name, so the
    console handler is attached only when the logger does not already carry
    a plain ``StreamHandler``; repeated calls would otherwise stack handlers
    and print every message several times.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)

    if not any(type(h) is logging.StreamHandler for h in logger.handlers):
        console_handler = logging.StreamHandler()
        console_handler.setLevel(logging.INFO)

        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        console_handler.setFormatter(formatter)
        logger.addHandler(console_handler)
    return logger


# Module-wide logger shared by the helpers in this file.
logger = get_logger('hy3dgen.shapgen')
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
class synchronize_timer:
    """ Synchronized timer to count the inference time of `nn.Module.forward`.

        Supports both context manager and decorator usage.  Timing only
        happens when the environment variable ``HY3DGEN_DEBUG`` equals ``'1'``
        at the moment the block is entered; otherwise the timer is a no-op
        and the reported time stays ``None``.

        Example as context manager:
        ```python
        with synchronize_timer('name') as t:
            run()
        ```

        Example as decorator:
        ```python
        @synchronize_timer('Export to trimesh')
        def export_to_trimesh(mesh_output):
            pass
        ```
    """

    def __init__(self, name=None):
        self.name = name
        # Elapsed milliseconds of the last timed run; None until one completes.
        self.time = None
        # Decision taken in __enter__ and reused in __exit__, so a change of
        # the environment variable in between cannot desynchronise the two.
        self._enabled = False

    def __enter__(self):
        """Context manager entry: start timing.

        Returns:
            A zero-argument callable that yields ``self.time``.
        """
        self._enabled = os.environ.get('HY3DGEN_DEBUG', '0') == '1'
        if self._enabled:
            self.start = torch.cuda.Event(enable_timing=True)
            self.end = torch.cuda.Event(enable_timing=True)
            self.start.record()
        # Always hand back the accessor so ``with ... as t`` is usable (and
        # ``t()`` is safe) even when timing is disabled.
        return lambda: self.time

    def __exit__(self, exc_type, exc_value, exc_tb):
        """Context manager exit: stop timing and log results."""
        if self._enabled:
            self.end.record()
            torch.cuda.synchronize()
            self.time = self.start.elapsed_time(self.end)
            if self.name is not None:
                logger.info(f'{self.name} takes {self.time} ms')

    def __call__(self, func):
        """Decorator: wrap the function to time its execution."""

        @wraps(func)
        def wrapper(*args, **kwargs):
            with self:
                result = func(*args, **kwargs)
            return result

        return wrapper
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
def smart_load_model(
    model_path,
    subfolder,
    use_safetensors,
    variant,
):
    """Locate a model's config and checkpoint, downloading them if needed.

    The model is first looked up under ``$HY3DGEN_MODELS`` (default
    ``~/.cache/hy3dgen``) at ``<base>/<model_path>/<subfolder>``.  If that
    directory does not exist, ``model_path`` is treated as a Hugging Face
    repo id and only ``<subfolder>/*`` is downloaded into
    ``<base>/<model_path>``.

    Args:
        model_path: Local model name below the base dir, or a hub repo id.
        subfolder: Sub-directory of the model that holds the files.
        use_safetensors: Selects the ``safetensors`` extension, else ``ckpt``.
        variant: Optional variant tag inserted as ``model.<variant>.<ext>``.

    Returns:
        Tuple ``(config_path, ckpt_path)``.  The existence of the two files
        themselves is not checked here, only that of the directory.

    Raises:
        RuntimeError: The local path is missing and ``huggingface_hub`` is
            not installed.
        FileNotFoundError: The model directory is still missing afterwards.
    """
    original_model_path = model_path
    # try local path
    base_dir = os.environ.get('HY3DGEN_MODELS', '~/.cache/hy3dgen')
    model_fld = os.path.expanduser(os.path.join(base_dir, model_path))
    model_path = os.path.expanduser(os.path.join(base_dir, model_path, subfolder))
    logger.info(f'Try to load model from local path: {model_path}')
    if not os.path.exists(model_path):
        logger.info('Model path not exists, try to download from huggingface')
        # Only the import is guarded; download errors propagate unchanged.
        try:
            from huggingface_hub import snapshot_download
        except ImportError as err:
            logger.warning(
                "You need to install HuggingFace Hub to load models from the hub."
            )
            raise RuntimeError(f"Model path {model_path} not found") from err

        # Download only the requested sub-directory.
        path = snapshot_download(
            repo_id=original_model_path,
            allow_patterns=[f"{subfolder}/*"],
            local_dir=model_fld
        )
        model_path = os.path.join(path, subfolder)

    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model path {original_model_path} not found")

    extension = 'ckpt' if not use_safetensors else 'safetensors'
    variant = '' if variant is None else f'.{variant}'
    ckpt_name = f'model{variant}.{extension}'
    config_path = os.path.join(model_path, 'config.yaml')
    ckpt_path = os.path.join(model_path, ckpt_name)
    return config_path, ckpt_path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ultrashape/utils/visualizers/__init__.py
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
# -*- coding: utf-8 -*-
|
|
|
|
|
|
ultrashape/utils/visualizers/color_util.py
DELETED
|
@@ -1,57 +0,0 @@
|
|
| 1 |
-
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
| 2 |
-
# except for the third-party components listed below.
|
| 3 |
-
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
| 4 |
-
# in the repsective licenses of these third-party components.
|
| 5 |
-
# Users must comply with all terms and conditions of original licenses of these third-party
|
| 6 |
-
# components and must ensure that the usage of the third party components adheres to
|
| 7 |
-
# all relevant laws and regulations.
|
| 8 |
-
|
| 9 |
-
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
| 10 |
-
# their software and algorithms, including trained model weights, parameters (including
|
| 11 |
-
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
| 12 |
-
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
| 13 |
-
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
| 14 |
-
|
| 15 |
-
import numpy as np
|
| 16 |
-
import matplotlib.pyplot as plt
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
# Helper functions
|
| 20 |
-
def get_colors(inp, colormap="viridis", normalize=True, vmin=None, vmax=None):
    """Map scalar values to RGB colours through a Matplotlib colormap.

    Args:
        inp: Array of scalar values.
        colormap: Name of the Matplotlib colormap.
        normalize: When true, a missing ``vmin`` / ``vmax`` is taken from the
            minimum / maximum of ``inp``.
        vmin: Optional lower bound of the colour range.
        vmax: Optional upper bound of the colour range.

    Returns:
        The first three (RGB) columns of the colormap output.
    """
    # ``plt.cm.get_cmap`` was deprecated and later removed from Matplotlib;
    # ``plt.get_cmap`` is the lookup that remains available.
    cmap = plt.get_cmap(colormap)
    if normalize:
        # Only fill in bounds the caller did not supply.  Previously explicit
        # vmin/vmax were overwritten by the data range and thus ignored.
        if vmin is None:
            vmin = np.min(inp)
        if vmax is None:
            vmax = np.max(inp)

    norm = plt.Normalize(vmin, vmax)
    return cmap(norm(inp))[:, :3]
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
def gen_checkers(n_checkers_x, n_checkers_y, width=256, height=256):
    """Build a yellow/black checkerboard texture.

    Args:
        n_checkers_x: Number of checkers along the first array axis.
        n_checkers_y: Number of checkers along the second array axis.
        width: Size of the first array axis in texels.
        height: Size of the second array axis in texels.

    Returns:
        ``float32`` array of shape ``(width, height, 3)``; texels whose
        checker indices sum to an even number are ``[1, 0.874, 0]``, the
        others ``[0, 0, 0]``.
    """
    # tex dims need to be power of two.
    # width in texels of each checker
    checker_w = width / n_checkers_x
    checker_h = height / n_checkers_y

    # Checker index per row / column, computed once per axis instead of once
    # per texel in a Python double loop (truncation matches ``int()``).
    ix = (np.arange(width) / checker_w).astype(int)
    iy = (np.arange(height) / checker_h).astype(int)
    even = (ix[:, None] + iy[None, :]) % 2 == 0

    array = np.zeros((width, height, 3), dtype='float32')
    array[even] = [1., 0.874, 0.0]
    return array
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
def gen_circle(width=256, height=256):
    """Build an RGBA disc mask texture.

    Returns:
        ``float32`` array of shape ``(width, height, 4)`` in which all four
        channels are 1 where the squared distance of the texel centre from
        the texture centre is at most ``width``, and 0 elsewhere.
    """
    rows, cols = np.mgrid[:width, :height]
    sq_dist = (rows - width / 2 + 0.5) ** 2 + (cols - height / 2 + 0.5) ** 2
    # NOTE(review): the squared distance is compared with ``width`` itself,
    # i.e. the disc radius is sqrt(width) texels -- confirm this is intended.
    inside = sq_dist <= width

    texture = np.empty((width, height, 4), dtype='float32')
    texture[...] = inside[:, :, None]
    return texture
|
| 57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ultrashape/utils/visualizers/html_util.py
DELETED
|
@@ -1,64 +0,0 @@
|
|
| 1 |
-
# -*- coding: utf-8 -*-
|
| 2 |
-
|
| 3 |
-
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
| 4 |
-
# except for the third-party components listed below.
|
| 5 |
-
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
| 6 |
-
# in the repsective licenses of these third-party components.
|
| 7 |
-
# Users must comply with all terms and conditions of original licenses of these third-party
|
| 8 |
-
# components and must ensure that the usage of the third party components adheres to
|
| 9 |
-
# all relevant laws and regulations.
|
| 10 |
-
|
| 11 |
-
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
| 12 |
-
# their software and algorithms, including trained model weights, parameters (including
|
| 13 |
-
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
| 14 |
-
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
| 15 |
-
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
| 16 |
-
|
| 17 |
-
import io
|
| 18 |
-
import base64
|
| 19 |
-
import numpy as np
|
| 20 |
-
from PIL import Image
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
def to_html_frame(content):
    """Wrap ``content`` in a minimal ``<html><body>`` document string.

    ``content`` is inserted verbatim (no escaping).
    """
    # NOTE(review): the whitespace inside this template is reconstructed; the
    # tags and their order are what the original emits.
    return f"""
    <html>
      <body>
      {content}
      </body>
    </html>
    """
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
def to_single_row_table(caption: str, content: str):
    """Return an HTML table with a caption and a single cell.

    Both ``caption`` and ``content`` are inserted verbatim (no escaping).
    """
    # NOTE(review): the whitespace inside this template is reconstructed; the
    # tags and their order are what the original emits.
    return f"""
    <table border = "1">
        <caption>{caption}</caption>
        <tr>
            <td>{content}</td>
        </tr>
    </table>
    """
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
def to_image_embed_tag(image: np.ndarray):
    """Encode ``image`` as PNG and return an ``<img>`` tag with a data URI.

    The array is turned into a PIL image, saved as PNG into memory and
    embedded base64-encoded in the ``src`` attribute.
    """
    # Convert np.ndarray to PNG bytes held in memory
    buffer = io.BytesIO()
    Image.fromarray(image).save(buffer, "PNG")

    # Encode bytes to base64
    encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")

    # NOTE(review): the whitespace around the tag is reconstructed.
    return f"""
    <img src="data:image/png;base64,{encoded}" alt="Embedded Image">
    """
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ultrashape/utils/visualizers/pythreejs_viewer.py
DELETED
|
@@ -1,549 +0,0 @@
|
|
| 1 |
-
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
| 2 |
-
# except for the third-party components listed below.
|
| 3 |
-
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
| 4 |
-
# in the repsective licenses of these third-party components.
|
| 5 |
-
# Users must comply with all terms and conditions of original licenses of these third-party
|
| 6 |
-
# components and must ensure that the usage of the third party components adheres to
|
| 7 |
-
# all relevant laws and regulations.
|
| 8 |
-
|
| 9 |
-
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
| 10 |
-
# their software and algorithms, including trained model weights, parameters (including
|
| 11 |
-
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
| 12 |
-
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
| 13 |
-
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
import numpy as np
|
| 17 |
-
from ipywidgets import embed
|
| 18 |
-
import pythreejs as p3s
|
| 19 |
-
import uuid
|
| 20 |
-
|
| 21 |
-
from .color_util import get_colors, gen_circle, gen_checkers
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
EMBED_URL = "https://cdn.jsdelivr.net/npm/@jupyter-widgets/html-manager@1.0.1/dist/embed-amd.js"
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
class PyThreeJSViewer(object):
|
| 28 |
-
|
| 29 |
-
def __init__(self, settings, render_mode="WEBSITE"):
    """Create the lights, camera, orbit controls, scene and renderer.

    Args:
        settings: Overrides for the default viewer settings.
        render_mode: Stored as ``self.render_mode``.
    """
    self.render_mode = render_mode
    self.__update_settings(settings)
    cfg = self.__s

    # The directional light is a child of the camera; the ambient light
    # belongs to the scene.
    self._light = p3s.DirectionalLight(color='white', position=[0, 0, 1], intensity=0.6)
    self._light2 = p3s.AmbientLight(intensity=0.5)
    self._cam = p3s.PerspectiveCamera(
        position=[0, 0, 1],
        lookAt=[0, 0, 0],
        fov=cfg["fov"],
        aspect=cfg["width"] / cfg["height"],
        children=[self._light],
    )
    self._orbit = p3s.OrbitControls(controlling=self._cam)
    self._scene = p3s.Scene(
        children=[self._cam, self._light2],
        background=cfg["background"],
    )
    self._renderer = p3s.Renderer(
        camera=self._cam,
        scene=self._scene,
        controls=[self._orbit],
        width=cfg["width"],
        height=cfg["height"],
        antialias=cfg["antialias"],
    )

    # Registry of top-level objects, keyed by a running integer id.
    self.__objects = {}
    self.__cnt = 0
|
| 44 |
-
|
| 45 |
-
def jupyter_mode(self):
    """Switch ``render_mode`` to ``"JUPYTER"``."""
    self.render_mode = "JUPYTER"
|
| 47 |
-
|
| 48 |
-
def offline(self):
    """Switch ``render_mode`` to ``"OFFLINE"``."""
    self.render_mode = "OFFLINE"
|
| 50 |
-
|
| 51 |
-
def website(self):
    """Switch ``render_mode`` to ``"WEBSITE"``."""
    self.render_mode = "WEBSITE"
|
| 53 |
-
|
| 54 |
-
def __get_shading(self, shading):
    """Return the default shading options overlaid with ``shading``.

    Every key present in ``shading`` replaces or extends the defaults; a new
    dict is returned and ``shading`` itself is not modified.
    """
    merged = {
        "flat": True,
        "wireframe": False,
        "wire_width": 0.03,
        "wire_color": "black",
        "side": 'DoubleSide',
        "colormap": "viridis",
        "normalize": [None, None],
        "bbox": False,
        "roughness": 0.5,
        "metalness": 0.25,
        "reflectivity": 1.0,
        "line_width": 1.0,
        "line_color": "black",
        "point_color": "red",
        "point_size": 0.01,
        "point_shape": "circle",
        "text_color": "red",
    }
    merged.update({key: shading[key] for key in shading})
    return merged
|
| 65 |
-
|
| 66 |
-
def __update_settings(self, settings=None):
    """Store the viewer settings: defaults overlaid with ``settings``.

    Args:
        settings: Optional mapping of overrides; ``None`` (the default) or
            an empty mapping keeps every default value.
    """
    # ``None`` replaces the former mutable ``{}`` default argument.
    sett = {"width": 1600, "height": 800, "antialias": True, "scale": 1.5, "background": "#ffffff",
            "fov": 30}
    if settings is not None:
        for k in settings:
            sett[k] = settings[k]
    self.__s = sett
|
| 72 |
-
|
| 73 |
-
def __add_object(self, obj, parent=None):
    """Attach ``obj["mesh"]`` to the scene or to ``parent`` and refresh the view.

    Without a parent the object is also registered in the object dict under
    the next running id; with a parent it is only added to that parent.

    Returns:
        The id of the most recently registered object in ``"JUPYTER"`` mode,
        ``self`` in ``"WEBSITE"`` mode, and ``None`` in any other mode.
    """
    if parent:
        # Object is added to parent object and NOT to objects dict
        parent.add(obj["mesh"])
    else:
        # Object is added to global scene and objects dict
        self.__objects[self.__cnt] = obj
        self.__cnt += 1
        self._scene.add(obj["mesh"])

    self.__update_view()

    if self.render_mode == "JUPYTER":
        return self.__cnt - 1
    if self.render_mode == "WEBSITE":
        return self
    return None
|
| 87 |
-
|
| 88 |
-
def __add_line_geometry(self, lines, shading, obj=None):
    """Create a line-segment object from ``lines`` and add it to the viewer.

    Args:
        lines: Array of segment end points; it is reshaped to ``(-1, 2, 3)``.
        shading: Mapping providing ``"line_width"`` and ``"line_color"``.
        obj: Optional parent object to attach the lines to.

    Returns:
        With a parent: tuple of the ``__add_object`` result and the line
        object dict.  Without: only the ``__add_object`` result.
    """
    points = lines.astype("float32", copy=False)
    lower = np.min(points, axis=0)
    upper = np.max(points, axis=0)

    geometry = p3s.LineSegmentsGeometry(positions=points.reshape((-1, 2, 3)))
    material = p3s.LineMaterial(linewidth=shading["line_width"], color=shading["line_color"])
    segments = p3s.LineSegments2(geometry=geometry, material=material)
    line_obj = {
        "geometry": geometry,
        "mesh": segments,
        "material": material,
        "max": upper,
        "min": lower,
        "type": "Lines",
        "wireframe": None,
    }

    if not obj:
        return self.__add_object(line_obj)
    return self.__add_object(line_obj, obj), line_obj
|
| 104 |
-
|
| 105 |
-
def __update_view(self):
    """Re-centre camera, light and orbit target on all registered objects.

    Does nothing while no object is registered.  Otherwise the joint
    bounding box of the objects' ``"min"`` / ``"max"`` entries is computed,
    the camera and light are placed on the +z side of its centre at a
    distance of ``scale * diagonal``, and the controls and projection matrix
    are refreshed.
    """
    if not self.__objects:
        return

    entries = list(self.__objects.values())
    upper = np.zeros((len(entries), 3))
    lower = np.zeros((len(entries), 3))
    for row, entry in enumerate(entries):
        upper[row] = entry["max"]
        lower[row] = entry["min"]
    upper = np.max(upper, axis=0)
    lower = np.min(lower, axis=0)

    diag = np.linalg.norm(upper - lower)
    center = ((upper - lower) / 2 + lower).tolist()
    distance = self.__s["scale"] * (diag)
    eye = [center[0], center[1], center[2] + distance]

    self._orbit.target = center
    self._cam.lookAt(center)
    self._cam.position = eye
    self._light.position = list(eye)

    self._orbit.exec_three_obj_method('update')
    self._cam.exec_three_obj_method('updateProjectionMatrix')
|
| 125 |
-
|
| 126 |
-
def __get_bbox(self, v):
    """Return corners and edges of the axis-aligned bounding box of ``v``.

    Returns:
        Tuple ``(v_box, f_box)``: the 8 box corners (4 at minimum z followed
        by 4 at maximum z) and 12 ``uint32`` index pairs connecting them.
    """
    lo = np.min(v, axis=0)
    hi = np.max(v, axis=0)
    x0, y0, z0 = lo[0], lo[1], lo[2]
    x1, y1, z1 = hi[0], hi[1], hi[2]

    # Corners of the bounding box
    v_box = np.array([
        [x0, y0, z0], [x1, y0, z0], [x1, y1, z0], [x0, y1, z0],
        [x0, y0, z1], [x1, y0, z1], [x1, y1, z1], [x0, y1, z1],
    ])

    # Bottom ring, top ring, then the four vertical edges
    f_box = np.array([
        [0, 1], [1, 2], [2, 3], [3, 0],
        [4, 5], [5, 6], [6, 7], [7, 4],
        [0, 4], [1, 5], [2, 6], [7, 3],
    ], dtype=np.uint32)
    return v_box, f_box
|
| 137 |
-
|
| 138 |
-
def __get_colors(self, v, f, c, sh):
    """Turn the colour argument ``c`` into a colour array for a mesh.

    ``c`` is interpreted by its type and size, in this order:
      * array of 3 values: one colour for every row of ``v``;
      * ``(n, 3)`` array: explicit colours per face or per vertex;
      * array with one value per face / per vertex: scalar values mapped
        through the colormap ``sh["colormap"]``;
      * anything else: the default colour (a message is printed unless
        ``c`` is ``None``).

    Returns:
        Tuple ``(colors, coloring)`` where ``coloring`` is ``"FaceColors"``
        for the per-face cases and ``"VertexColors"`` otherwise.
    """

    def default_colors():
        # Fallback colour repeated for every row of ``v``.
        fallback = np.ones_like(v)
        fallback[:, 0] = 1.0
        fallback[:, 1] = 0.874
        fallback[:, 2] = 0.0
        return fallback

    def function_colors(values):
        # The normalize flag is on only when both bounds are provided.
        lo, hi = sh["normalize"][0], sh["normalize"][1]
        normalize = lo is not None and hi is not None
        return get_colors(values, sh["colormap"], normalize=normalize, vmin=lo, vmax=hi)

    coloring = "VertexColors"
    is_array = isinstance(c, np.ndarray)

    if is_array and c.size == 3:  # Single color
        colors = np.ones_like(v)
        colors[:, 0] = c[0]
        colors[:, 1] = c[1]
        colors[:, 2] = c[2]
    elif is_array and c.ndim == 2 and c.shape[1] == 3:  # Color values for
        if c.shape[0] == f.shape[0]:  # faces
            colors = np.hstack([c, c, c]).reshape((-1, 3))
            coloring = "FaceColors"
        elif c.shape[0] == v.shape[0]:  # vertices
            colors = c
        else:  # Wrong size, fallback
            print("Invalid color array given! Supported are numpy arrays.", type(c))
            colors = default_colors()
    elif is_array and c.size == f.shape[0]:  # Function values for faces
        cc = function_colors(c)
        colors = np.hstack([cc, cc, cc]).reshape((-1, 3))
        coloring = "FaceColors"
    elif is_array and c.size == v.shape[0]:  # Function values for vertices
        colors = function_colors(c)
    else:
        colors = default_colors()

        # No color
        if c is not None:
            print("Invalid color array given! Supported are numpy arrays.", type(c))

    return colors, coloring
|
| 185 |
-
|
| 186 |
-
def __get_point_colors(self, v, c, sh):
    """Resolve the color specification of a point set.

    Args:
        v: Point array; only ``v.shape[0]`` (the number of points) is read.
        c: Color specification. ``None`` selects ``sh["point_color"]``; a
            string is passed through as one color for all points; an
            ``(N, 3)`` array is used directly as per-point colors; an
            ``(N, K)`` array with ``K != 3`` is reduced to its per-row L2
            norm and color-mapped; any other array holding exactly ``N``
            values is color-mapped as-is.
        sh: Shading dictionary; the keys ``"point_color"``, ``"colormap"``
            and ``"normalize"`` are read.

    Returns:
        Tuple ``(colors, v_color)``. ``v_color`` is ``True`` when ``colors``
        holds one entry per point and ``False`` when it is a single color.
    """

    def color_mapped(values):
        # Normalization is requested only when both bounds are configured.
        vmin, vmax = sh["normalize"][0], sh["normalize"][1]
        normalize = vmin is not None and vmax is not None
        mapped = get_colors(values, sh["colormap"], normalize=normalize,
                            vmin=vmin, vmax=vmax)
        return mapped.astype("float32", copy=False)

    if c is None:  # No color given, use the global point color
        return sh["point_color"], False
    if isinstance(c, str):  # One named color for every point
        return c, False

    if isinstance(c, np.ndarray):
        n_points = v.shape[0]
        if c.ndim == 2 and c.shape[0] == n_points:
            if c.shape[1] == 3:
                # Explicit per-point colors
                return c.astype("float32", copy=False), True
            # Per-point feature vectors: color by their magnitude
            return color_mapped(np.linalg.norm(c, ord=2, axis=-1)), True
        if c.size == n_points:
            # One function value per point
            return color_mapped(c), True

    print("Invalid color array given! Supported are numpy arrays.", type(c))
    return sh["point_color"], False
|
| 221 |
-
|
| 222 |
-
def add_mesh(self, v, f, c=None, uv=None, n=None, shading={}, texture_data=None, **kwargs):
    """Add a triangle (or tetrahedral) mesh to the scene.

    Args:
        v: Vertex array with 2 or 3 columns. A zero third column is appended
            to 2-column input; the array is cast to ``float32``.
        f: Face index array. With 3-column vertices and 4-column faces every
            row is treated as a tetrahedron and split into four triangles.
        c: Color specification, forwarded to ``__get_colors``.
        uv: Optional texture coordinates. When given they are rescaled with
            the global min/max of the array and the mesh is textured instead
            of vertex-colored.
        n: Optional normals; only attached when the coloring mode is
            ``"VertexColors"``.
        shading: Shading overrides. The dictionary is not modified.
        texture_data: Texture used together with ``uv``; defaults to
            ``gen_checkers(20, 20)``.
        **kwargs: Additional shading options, taking precedence over
            ``shading``.

    Returns:
        The result of ``__add_object`` for the new mesh record.
    """
    # Merge into a fresh dict: updating ``shading`` in place would leak the
    # options into the shared default argument and into the caller's dict.
    shading = {**shading, **kwargs}
    sh = self.__get_shading(shading)
    mesh_obj = {}

    # It is a tet mesh: emit the four triangular faces of every tetrahedron,
    # keeping the rows of one tetrahedron adjacent.
    if v.shape[1] == 3 and f.shape[1] == 4:
        tet_faces = np.array([[1, 0, 2], [0, 1, 3], [1, 2, 3], [2, 0, 3]])
        f = f[:, tet_faces].reshape(-1, 3)

    if v.shape[1] == 2:
        v = np.append(v, np.zeros([v.shape[0], 1]), 1)

    # Type adjustment vertices
    v = v.astype("float32", copy=False)

    # Color setup
    colors, coloring = self.__get_colors(v, f, c, sh)

    # Type adjustment colors
    c = colors.astype("float32", copy=False)

    # Material and geometry setup
    ba_dict = {"color": p3s.BufferAttribute(c)}
    if coloring == "FaceColors":
        # Face colors need unshared vertices: three consecutive rows per face.
        v = v[f[:, :3].ravel()]
    else:
        f = f.astype("uint32", copy=False).ravel()
        ba_dict["index"] = p3s.BufferAttribute(f, normalized=False)

    ba_dict["position"] = p3s.BufferAttribute(v, normalized=False)

    if uv is not None:
        uv = (uv - np.min(uv)) / (np.max(uv) - np.min(uv))
        if texture_data is None:
            texture_data = gen_checkers(20, 20)
        tex = p3s.DataTexture(data=texture_data, format="RGBFormat", type="FloatType")
        material = p3s.MeshStandardMaterial(map=tex, reflectivity=sh["reflectivity"], side=sh["side"],
                                            roughness=sh["roughness"], metalness=sh["metalness"],
                                            flatShading=sh["flat"],
                                            polygonOffset=True, polygonOffsetFactor=1, polygonOffsetUnits=5)
        ba_dict["uv"] = p3s.BufferAttribute(uv.astype("float32", copy=False))
    else:
        material = p3s.MeshStandardMaterial(vertexColors=coloring, reflectivity=sh["reflectivity"],
                                            side=sh["side"], roughness=sh["roughness"], metalness=sh["metalness"],
                                            flatShading=sh["flat"],
                                            polygonOffset=True, polygonOffsetFactor=1, polygonOffsetUnits=5)

    # TODO: properly handle normals for FaceColors as well
    if n is not None and coloring == "VertexColors":
        ba_dict["normal"] = p3s.BufferAttribute(n.astype("float32", copy=False), normalized=True)

    geometry = p3s.BufferGeometry(attributes=ba_dict)

    # Without user-supplied normals, let three.js compute them.
    if n is None:
        if coloring == "VertexColors":
            geometry.exec_three_obj_method('computeVertexNormals')
        elif coloring == "FaceColors":
            geometry.exec_three_obj_method('computeFaceNormals')

    # Mesh setup
    mesh = p3s.Mesh(geometry=geometry, material=material)

    # Wireframe setup
    mesh_obj["wireframe"] = None
    if sh["wireframe"]:
        wf_geometry = p3s.WireframeGeometry(mesh.geometry)
        wf_material = p3s.LineBasicMaterial(color=sh["wire_color"], linewidth=sh["wire_width"])
        wireframe = p3s.LineSegments(wf_geometry, wf_material)
        mesh.add(wireframe)
        mesh_obj["wireframe"] = wireframe

    # Bounding box setup
    if sh["bbox"]:
        v_box, f_box = self.__get_bbox(v)
        _, bbox = self.add_edges(v_box, f_box, sh, mesh)
        mesh_obj["bbox"] = [bbox, v_box, f_box]

    # Object setup
    mesh_obj["max"] = np.max(v, axis=0)
    mesh_obj["min"] = np.min(v, axis=0)
    mesh_obj["geometry"] = geometry
    mesh_obj["mesh"] = mesh
    mesh_obj["material"] = material
    mesh_obj["type"] = "Mesh"
    mesh_obj["shading"] = sh
    mesh_obj["coloring"] = coloring
    mesh_obj["arrays"] = [v, f, c]  # TODO replace with proper storage or remove if not needed

    return self.__add_object(mesh_obj)
|
| 321 |
-
|
| 322 |
-
def add_lines(self, beginning, ending, shading={}, obj=None, **kwargs):
    """Add line segments running from ``beginning`` to ``ending``.

    Args:
        beginning: Start point(s): a 1-D array for a single point or a 2-D
            array with one point per row, with 2 or 3 components.
        ending: End point(s), in the same layout as ``beginning``.
        shading: Shading overrides. The dictionary is not modified.
        obj: Forwarded unchanged to ``__add_line_geometry``.
        **kwargs: Additional shading options, taking precedence over
            ``shading``.

    Returns:
        The result of ``__add_line_geometry``.
    """
    # Merge into a fresh dict so neither the shared default argument nor the
    # caller's dictionary is modified.
    shading = {**shading, **kwargs}

    def as_rows_3d(pts):
        # One row per point; 2-D points get a zero z coordinate.
        pts = np.atleast_2d(pts)
        if pts.shape[1] == 2:
            pts = np.append(pts, np.zeros([pts.shape[0], 1]), 1)
        return pts

    beginning = as_rows_3d(beginning)
    ending = as_rows_3d(ending)

    sh = self.__get_shading(shading)
    # Interleave the points so that rows (2k, 2k + 1) form segment k.
    lines = np.hstack([beginning, ending]).reshape((-1, 3))
    return self.__add_line_geometry(lines, sh, obj)
|
| 343 |
-
|
| 344 |
-
def add_edges(self, vertices, edges, shading={}, obj=None, **kwargs):
    """Add line segments given as vertex index pairs.

    Args:
        vertices: 2-D vertex array with 2 or 3 columns; a zero third column
            is appended to 2-column input.
        edges: 2-D integer array; the first two entries of every row are
            the indices of the segment's end points.
        shading: Shading overrides. The dictionary is not modified.
        obj: Forwarded unchanged to ``__add_line_geometry``.
        **kwargs: Additional shading options, taking precedence over
            ``shading``.

    Returns:
        The result of ``__add_line_geometry``.
    """
    # Merge into a fresh dict so neither the shared default argument nor the
    # caller's dictionary is modified.
    shading = {**shading, **kwargs}
    if vertices.shape[1] == 2:
        vertices = np.append(vertices, np.zeros([vertices.shape[0], 1]), 1)
    sh = self.__get_shading(shading)

    # Rows (2k, 2k + 1) hold the two end points of edge k. The buffer keeps
    # its historical ``edges.size`` rows; rows beyond the end points stay 0.
    lines = np.zeros((edges.size, 3))
    endpoints = edges[:, :2].reshape(-1)
    lines[:endpoints.size] = vertices[endpoints]
    return self.__add_line_geometry(lines, sh, obj)
|
| 357 |
-
|
| 358 |
-
def add_points(self, points, c=None, shading={}, obj=None, **kwargs):
    """Add a point cloud to the scene.

    Args:
        points: A 1-D array for a single point or a 2-D array with one point
            per row, with 2 or 3 components. 2-D points get a zero z
            coordinate; the array is cast to ``float32``.
        c: Color specification, forwarded to ``__get_point_colors``.
        shading: Shading overrides. The dictionary is not modified.
        obj: Optional parent, forwarded to ``__add_object``.
        **kwargs: Additional shading options, taking precedence over
            ``shading``.

    Returns:
        With a truthy ``obj``: the tuple ``(__add_object(...), point_obj)``;
        otherwise just the result of ``__add_object``.
    """
    # Merge into a fresh dict so neither the shared default argument nor the
    # caller's dictionary is modified.
    shading = {**shading, **kwargs}

    # A single point becomes a one-row array, so the per-axis min/max and the
    # per-point color checks below see a proper (N, 3) layout.
    points = np.atleast_2d(points)
    if points.shape[1] == 2:
        points = np.append(points, np.zeros([points.shape[0], 1]), 1)

    sh = self.__get_shading(shading)
    points = points.astype("float32", copy=False)
    mi = np.min(points, axis=0)
    ma = np.max(points, axis=0)

    g_attributes = {"position": p3s.BufferAttribute(points, normalized=False)}
    m_attributes = {"size": sh["point_size"]}

    if sh["point_shape"] == "circle":  # Plot circles; any other shape plots squares
        tex = p3s.DataTexture(data=gen_circle(16, 16), format="RGBAFormat", type="FloatType")
        m_attributes["map"] = tex
        m_attributes["alphaTest"] = 0.5
        # NOTE(review): three.js materials name this flag "transparent";
        # confirm whether "transparency" has any effect here.
        m_attributes["transparency"] = True

    colors, v_colors = self.__get_point_colors(points, c, sh)
    if v_colors:  # Colors per point
        m_attributes["vertexColors"] = 'VertexColors'
        g_attributes["color"] = p3s.BufferAttribute(colors, normalized=False)
    else:  # Colors for all points
        m_attributes["color"] = colors

    material = p3s.PointsMaterial(**m_attributes)
    geometry = p3s.BufferGeometry(attributes=g_attributes)
    cloud = p3s.Points(geometry=geometry, material=material)
    point_obj = {"geometry": geometry, "mesh": cloud, "material": material,
                 "max": ma, "min": mi, "type": "Points", "wireframe": None}

    if obj:
        return self.__add_object(point_obj, obj), point_obj
    else:
        return self.__add_object(point_obj)
|
| 401 |
-
|
| 402 |
-
def remove_object(self, obj_id):
    """Remove the object registered under ``obj_id`` from the scene.

    An unknown id only prints the list of valid ids. Otherwise the object's
    mesh is removed from the scene, its record is dropped and the view is
    updated.
    """
    registry = self.__objects
    if obj_id in registry:
        self._scene.remove(registry[obj_id]["mesh"])
        del registry[obj_id]
        self.__update_view()
    else:
        print("Invalid object id. Valid ids are: ", list(registry.keys()))
|
| 409 |
-
|
| 410 |
-
def reset(self):
    """Remove every registered object from the scene, then update the view."""
    registry = self.__objects
    # Iterate over a snapshot of the ids because entries are deleted on the go.
    for obj_id in tuple(registry):
        self._scene.remove(registry[obj_id]["mesh"])
        del registry[obj_id]
    self.__update_view()
|
| 415 |
-
|
| 416 |
-
def update_object(self, oid=0, vertices=None, colors=None, faces=None):
    """Update the geometry buffers of an already added object in place.

    Args:
        oid: Id of the object to update.
        vertices: Optional new vertex positions. For objects in
            ``"FaceColors"`` mode they are expanded to three rows per stored
            face; otherwise they are used as given (cast to ``float32``).
        colors: Optional new color specification, resolved with
            ``__get_colors`` against the stored arrays and shading.
        faces: Optional new face indices; only supported for objects that
            are not in ``"FaceColors"`` mode.

    Returns:
        ``self`` when ``render_mode`` is ``"WEBSITE"``, otherwise ``None``.
        A rejected face update also returns ``None``.
    """
    obj = self.__objects[oid]
    attributes = obj["geometry"].attributes

    if vertices is not None:
        if obj["coloring"] == "FaceColors":
            # Face-colored geometry stores unshared vertices: rebuild the
            # three rows per face from the stored face array.
            f = obj["arrays"][1]
            v = np.asarray(vertices)[f[:, :3].ravel()].astype("float32", copy=False)
        else:
            v = vertices.astype("float32", copy=False)
        attributes["position"].array = v
        attributes["position"].needsUpdate = True

    if colors is not None:
        colors, coloring = self.__get_colors(obj["arrays"][0], obj["arrays"][1], colors, obj["shading"])
        colors = colors.astype("float32", copy=False)
        attributes["color"].array = colors
        attributes["color"].needsUpdate = True

    if faces is not None:
        if obj["coloring"] == "FaceColors":
            print("Face updates are currently only possible in vertex color mode.")
            return
        attributes["index"].array = faces.astype("uint32", copy=False).ravel()
        attributes["index"].needsUpdate = True

    if self.render_mode == "WEBSITE":
        return self
|
| 455 |
-
|
| 456 |
-
# def update(self):
|
| 457 |
-
# self.mesh.exec_three_obj_method('update')
|
| 458 |
-
# self.orbit.exec_three_obj_method('update')
|
| 459 |
-
# self.cam.exec_three_obj_method('updateProjectionMatrix')
|
| 460 |
-
# self.scene.exec_three_obj_method('update')
|
| 461 |
-
|
| 462 |
-
def add_text(self, text, shading={}, **kwargs):
    """Add a text sprite showing ``text`` to the scene.

    Args:
        text: String to render.
        shading: Shading overrides; ``"text_color"`` of the resolved shading
            is used. The dictionary is not modified.
        **kwargs: Additional shading options, taking precedence over
            ``shading``.
    """
    # Merge into a fresh dict so neither the shared default argument nor the
    # caller's dictionary is modified.
    shading = {**shading, **kwargs}
    sh = self.__get_shading(shading)
    texture = p3s.TextTexture(string=text, color=sh["text_color"])
    sprite_material = p3s.SpriteMaterial(map=texture)
    sprite = p3s.Sprite(material=sprite_material, scaleToTexture=True)
    self._scene.add(sprite)
|
| 469 |
-
|
| 470 |
-
# def add_widget(self, widget, callback):
|
| 471 |
-
# self.widgets.append(widget)
|
| 472 |
-
# widget.observe(callback, names='value')
|
| 473 |
-
|
| 474 |
-
# def add_dropdown(self, options, default, desc, cb):
|
| 475 |
-
# widget = widgets.Dropdown(options=options, value=default, description=desc)
|
| 476 |
-
# self.__widgets.append(widget)
|
| 477 |
-
# widget.observe(cb, names="value")
|
| 478 |
-
# display(widget)
|
| 479 |
-
|
| 480 |
-
# def add_button(self, text, cb):
|
| 481 |
-
# button = widgets.Button(description=text)
|
| 482 |
-
# self.__widgets.append(button)
|
| 483 |
-
# button.on_click(cb)
|
| 484 |
-
# display(button)
|
| 485 |
-
|
| 486 |
-
def to_html(self, imports=True, html_frame=True):
    """Export the current scene as an embeddable HTML snippet.

    The positions of all objects are temporarily shifted so that the center
    of the joint bounding box sits at the origin (fixes a centering bug in
    offline rendering). The shift is undone before returning, also when the
    export raises.

    Args:
        imports: When false, ``embed.load_requirejs_template`` is blanked
            while the snippet is generated.
        html_frame: When true, the snippet is wrapped in ``<html><body>``
            tags.

    Returns:
        The HTML string, or ``None`` when no objects have been added.
    """
    if len(self.__objects) == 0:
        return

    records = list(self.__objects.values())

    # Joint bounding box of all objects.
    ma = np.zeros((len(records), 3))
    mi = np.zeros((len(records), 3))
    for row, record in enumerate(records):
        ma[row] = record["max"]
        mi[row] = record["min"]
    ma = np.max(ma, axis=0)
    mi = np.min(mi, axis=0)
    diag = np.linalg.norm(ma - mi)
    mean = (ma - mi) / 2 + mi

    # Bake the centering into the position buffers (in place).
    for record in records:
        v = record["geometry"].attributes["position"].array
        v -= mean

    try:
        scale = self.__s["scale"] * (diag)
        self._orbit.target = [0.0, 0.0, 0.0]
        self._cam.lookAt([0.0, 0.0, 0.0])
        self._cam.position = [0.0, 0.5, scale * 1.3]  #! show four complete meshes in the window
        self._light.position = [0.0, 0.0, scale]

        state = embed.dependency_state(self._renderer)

        # Somehow these entries are missing when the state is exported in python.
        # Exporting from the GUI works, so we are inserting the missing entries.
        for k in state:
            if state[k]["model_name"] == "OrbitControlsModel":
                state[k]["state"]["maxAzimuthAngle"] = "inf"
                state[k]["state"]["maxDistance"] = "inf"
                state[k]["state"]["maxZoom"] = "inf"
                state[k]["state"]["minAzimuthAngle"] = "-inf"

        # The template is module-level state of ``embed``: always restore it.
        tpl = embed.load_requirejs_template
        if not imports:
            embed.load_requirejs_template = ""
        try:
            s = embed.embed_snippet(self._renderer, state=state, embed_url=EMBED_URL)
        finally:
            embed.load_requirejs_template = tpl

        if html_frame:
            s = "<html>\n<body>\n" + s + "\n</body>\n</html>"
    finally:
        # Revert the centering so the live scene is left unchanged.
        for record in records:
            v = record["geometry"].attributes["position"].array
            v += mean
        self.__update_view()

    return s
|
| 540 |
-
|
| 541 |
-
def save(self, filename=""):
    """Write the scene as a standalone HTML file.

    Args:
        filename: Target path. A missing ``.html`` suffix is appended; an
            empty string selects a random ``<uuid4>.html`` name in the
            current directory.

    Nothing is written when ``to_html()`` yields no markup (``None``).
    """
    suffix = ".html"
    if filename == "":
        uid = str(uuid.uuid4()) + suffix
    elif filename.endswith(suffix):
        # Only the trailing extension matters; ".html" elsewhere in the path
        # (e.g. in a directory name) must be left alone.
        uid = filename
    else:
        uid = filename + suffix

    # Render before opening the file so a failed export leaves no empty file.
    html = self.to_html()
    if html is None:
        print("Nothing to save: the plot contains no objects.")
        return

    with open(uid, "w", encoding="utf-8") as f:
        f.write(html)
    print("Plot saved to file %s." % uid)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ultrashape/utils/voxelize.py
DELETED
|
@@ -1,74 +0,0 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
|
| 3 |
-
def voxelize_from_point(pc, num_latents, resolution=128):
    """Sample ``num_latents`` occupied voxels per batch element.

    The first three channels of ``pc`` are read as coordinates, mapped from
    ``[-1, 1]`` onto a ``resolution``-cubed grid (out-of-range values are
    clamped to the border cells), de-duplicated per batch element and
    randomly shuffled. Each batch element then contributes its first
    ``num_latents`` shuffled voxels; if it has fewer, its voxels are repeated
    cyclically to fill the requested size.

    Args:
        pc: Tensor of shape ``(B, N, D)`` with ``D >= 3``; channels beyond
            the first three are ignored.
        num_latents: Number of voxels to return per batch element.
        resolution: Number of grid cells along each axis.

    Returns:
        Tuple ``(sampled_pc, sampled_indices)``: the voxel centers in
        ``[-1, 1]`` as a float tensor and the integer grid coordinates, both
        of shape ``(B, num_latents, 3)``.

    Raises:
        ValueError: If ``pc`` is not a non-empty ``(B, N, D >= 3)`` tensor.
    """
    if pc.dim() != 3 or pc.shape[-1] < 3:
        raise ValueError(f"pc must have shape (B, N, D) with D >= 3, got {tuple(pc.shape)}")
    B, N = pc.shape[0], pc.shape[1]
    if B == 0 or N == 0:
        raise ValueError(f"pc must contain at least one batch element and one point, got {tuple(pc.shape)}")
    device = pc.device

    # Coordinates -> integer grid cells; only xyz is used so extra channels
    # (e.g. normals) cannot corrupt the (batch, x, y, z) layout below.
    norm_pc = (pc[..., :3] + 1.0) / 2.0
    voxel_indices = torch.floor(norm_pc * resolution).long()
    voxel_indices = torch.clamp(voxel_indices, 0, resolution - 1)  # (B, N, 3)

    batch_idx = torch.arange(B, device=device).view(B, 1, 1).expand(B, N, 1)
    flat_indices = torch.cat([batch_idx, voxel_indices], dim=-1).reshape(-1, 4)
    unique_voxels = torch.unique(flat_indices, dim=0)
    u_batch_ids = unique_voxels[:, 0]

    # Shuffle within each batch element: order everything by random noise,
    # then regroup by batch id with a stable sort. Unlike sorting by the
    # float key ``batch_id + noise``, this cannot round a key up into the
    # next batch element.
    noise = torch.rand_like(u_batch_ids, dtype=torch.float)
    shuffle = torch.argsort(noise)
    regroup = torch.sort(u_batch_ids[shuffle], stable=True).indices
    shuffled_voxels = unique_voxels[shuffle[regroup]]
    shuffled_batch_ids = shuffled_voxels[:, 0].contiguous()

    counts = torch.bincount(shuffled_batch_ids, minlength=B)
    min_count = counts.min().item()

    # Always aim for num_latents
    actual_k = num_latents
    if min_count < num_latents:
        print(f"[Info] Voxel count ({min_count}) < Target ({num_latents}). Sampling with replacement.")

    # Batch element i owns rows [batch_starts[i], batch_starts[i] + counts[i]).
    batch_starts = torch.searchsorted(shuffled_batch_ids, torch.arange(B, device=device))
    offsets = torch.arange(actual_k, device=device).unsqueeze(0)  # [1, K]

    # Wrapping the offsets repeats the available voxels when a batch element
    # has fewer than K; with K or more voxels the modulo is a no-op and the
    # first K shuffled voxels are taken. The clamp guards the modulo against 0.
    wrapped_offsets = offsets % counts.clamp(min=1).unsqueeze(1)  # [B, K]
    gather_indices = (batch_starts.unsqueeze(1) + wrapped_offsets).view(-1)

    final_grid_coords = shuffled_voxels[gather_indices][:, 1:]

    # Grid Index -> Voxel Center
    voxel_size = 2.0 / resolution
    final_centers = (final_grid_coords.float() + 0.5) * voxel_size - 1.0

    sampled_pc = final_centers.view(B, actual_k, 3)
    sampled_indices = final_grid_coords.view(B, actual_k, 3)

    return sampled_pc, sampled_indices
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|