Spaces:
Running on Zero
Running on Zero
File size: 6,397 Bytes
afea36f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 | # Copied from the TRELLIS project:
# https://github.com/microsoft/TRELLIS
# Original license: MIT
# Copyright (c) the TRELLIS authors
from typing import *
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...modules import sparse as sp
from ...utils.random_utils import hammersley_sequence
from .base import SparseTransformerBase
from ...representations import Gaussian
from ..sparse_elastic_mixin import SparseTransformerElasticMixin
from .. import from_pretrained
class SLatGaussianDecoder(SparseTransformerBase):
    """
    Sparse-transformer decoder that turns structured latents into 3D Gaussians.

    Each active voxel of the input sparse tensor predicts
    ``representation_config['num_gaussians']`` Gaussians, i.e. 14 output
    channels per Gaussian (xyz offset 3, DC colour 3, scaling 3, rotation 4,
    opacity 1). See ``_calc_layout`` for the exact channel layout.

    ``representation_config`` keys read by this class: ``num_gaussians``,
    ``voxel_size``, ``perturb_offset``, ``lr`` (per-attribute multipliers keyed
    by attribute name), ``3d_filter_kernel_size``, ``scaling_bias``,
    ``opacity_bias`` and ``scaling_activation``.
    """

    def __init__(
        self,
        resolution: int,
        model_channels: int,
        latent_channels: int,
        num_blocks: int,
        num_heads: Optional[int] = None,
        num_head_channels: Optional[int] = 64,
        mlp_ratio: float = 4,
        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
        window_size: int = 8,
        pe_mode: Literal["ape", "rope"] = "ape",
        use_fp16: bool = False,
        use_checkpoint: bool = False,
        qk_rms_norm: bool = False,
        representation_config: Optional[dict] = None,
        pretrained_gs_dec: Optional[str] = None,
    ):
        """
        Args:
            resolution: Side length of the voxel grid; used to normalise voxel
                coordinates to [0, 1] and to scale xyz offsets.
            model_channels: Width of the transformer trunk.
            latent_channels: Feature channels of the input sparse latent.
            num_blocks, num_heads, num_head_channels, mlp_ratio, attn_mode,
            window_size, pe_mode, use_checkpoint, qk_rms_norm:
                Forwarded unchanged to ``SparseTransformerBase``.
            use_fp16: Forwarded to the base class; also triggers
                ``convert_to_fp16()`` once construction (and any weight
                loading) is done.
            representation_config: Required dict describing the Gaussian
                output (see class docstring for the keys read).
            pretrained_gs_dec: Optional weights to load. A path ending in
                ``.pt`` is read as a raw state dict with ``torch.load``;
                anything else is handed to ``from_pretrained``.

        Raises:
            TypeError: If ``representation_config`` is None.
        """
        if representation_config is None:
            # Fail fast with a clear message instead of a later
            # "'NoneType' object is not subscriptable" from _calc_layout.
            raise TypeError("SLatGaussianDecoder: representation_config is required (got None)")
        super().__init__(
            in_channels=latent_channels,
            model_channels=model_channels,
            num_blocks=num_blocks,
            num_heads=num_heads,
            num_head_channels=num_head_channels,
            mlp_ratio=mlp_ratio,
            attn_mode=attn_mode,
            window_size=window_size,
            pe_mode=pe_mode,
            use_fp16=use_fp16,
            use_checkpoint=use_checkpoint,
            qk_rms_norm=qk_rms_norm,
        )
        self.resolution = resolution
        self.rep_config = representation_config
        # _calc_layout sets self.out_channels, which the output layer needs.
        self._calc_layout()
        self.out_layer = sp.SparseLinear(model_channels, self.out_channels)
        self._build_perturbation()
        self.initialize_weights()
        if pretrained_gs_dec is not None:
            print(f'loading pretrained weight: {pretrained_gs_dec}')
            if pretrained_gs_dec.endswith('.pt'):
                # Raw state-dict checkpoint; weights_only=True restricts
                # unpickling to tensors and plain containers.
                state_dict = torch.load(pretrained_gs_dec, map_location='cpu', weights_only=True)
            else:
                # Presumably a model path / repo id resolved by from_pretrained.
                state_dict = from_pretrained(pretrained_gs_dec).state_dict()
            self.load_state_dict(state_dict)
            del state_dict
        # Convert only after the (optional) pretrained weights are in place.
        if use_fp16:
            self.convert_to_fp16()

    def initialize_weights(self) -> None:
        """Initialise weights via the base class, then zero the output projection."""
        super().initialize_weights()
        # Zero-init the output layer so every raw predicted attribute starts at 0
        # (xyz offsets then come only from the perturbation, if enabled).
        nn.init.constant_(self.out_layer.weight, 0)
        nn.init.constant_(self.out_layer.bias, 0)

    def _build_perturbation(self) -> None:
        """
        Register the fixed ``offset_perturbation`` buffer.

        Row ``i`` is ``atanh((2 * h_i - 1) / voxel_size)`` where ``h_i`` is the
        i-th 3D Hammersley point, so the buffer is expected to have shape
        (num_gaussians, 3). ``to_representation`` adds it to the raw xyz offset
        before ``tanh``; with a zero raw offset, tanh undoes the atanh and the
        Gaussians land on the Hammersley pattern inside their voxel.
        """
        n = self.rep_config['num_gaussians']
        points = torch.tensor([hammersley_sequence(3, i, n) for i in range(n)]).float()
        # atanh is finite only on (-1, 1). NOTE(review): assuming
        # hammersley_sequence returns values in [0, 1), (2h - 1) / voxel_size
        # stays inside that interval only when voxel_size > 1 — confirm.
        perturbation = torch.atanh((points * 2 - 1) / self.rep_config['voxel_size']).to(self.device)
        self.register_buffer('offset_perturbation', perturbation)

    def _calc_layout(self) -> None:
        """
        Describe how the output channels split into Gaussian attributes.

        Fills ``self.layout`` with, for each attribute in channel order, its
        per-voxel ``shape``, flattened ``size`` and half-open ``range`` of
        channels, and sets ``self.out_channels`` to the total channel count
        (14 * num_gaussians).
        """
        n = self.rep_config['num_gaussians']
        # Dict insertion order defines the channel order of the output layer.
        self.layout = {
            '_xyz': {'shape': (n, 3), 'size': n * 3},
            '_features_dc': {'shape': (n, 1, 3), 'size': n * 3},
            '_scaling': {'shape': (n, 3), 'size': n * 3},
            '_rotation': {'shape': (n, 4), 'size': n * 4},
            '_opacity': {'shape': (n, 1), 'size': n},
        }
        start = 0
        for v in self.layout.values():
            v['range'] = (start, start + v['size'])
            start += v['size']
        self.out_channels = start

    def to_representation(self, x: sp.SparseTensor) -> List[Gaussian]:
        """
        Convert a batch of network outputs to 3D Gaussian representations.

        Args:
            x: The [N x * x C] sparse tensor output by the network, with
                C == ``self.out_channels``.

        Returns:
            A list with one ``Gaussian`` per batch element of ``x``.
        """
        cfg = self.rep_config
        ret = []
        for i in range(x.shape[0]):
            representation = Gaussian(
                sh_degree=0,
                aabb=[-0.5, -0.5, -0.5, 1.0, 1.0, 1.0],
                # Keyword spelling must match Gaussian's constructor; don't "fix" it here.
                mininum_kernel_size=cfg['3d_filter_kernel_size'],
                scaling_bias=cfg['scaling_bias'],
                opacity_bias=cfg['opacity_bias'],
                scaling_activation=cfg['scaling_activation'],
            )
            # Rows of the batched sparse tensor belonging to item i; sliced once
            # and reused for every attribute below.
            coords_i = x.coords[x.layout[i]]
            feats_i = x.feats[x.layout[i]]
            # Voxel centres normalised by the grid resolution (column 0 of coords
            # is dropped — presumably the batch index).
            xyz = (coords_i[:, 1:].float() + 0.5) / self.resolution
            for k, v in self.layout.items():
                start, end = v['range']
                attr = feats_i[:, start:end].reshape(-1, *v['shape'])
                attr = attr * cfg['lr'][k]
                if k == '_xyz':
                    if cfg['perturb_offset']:
                        attr = attr + self.offset_perturbation
                    # tanh bounds the offset to at most 0.5 * voxel_size voxel
                    # widths from the voxel centre along each axis.
                    attr = torch.tanh(attr) / self.resolution * 0.5 * cfg['voxel_size']
                    attr = xyz.unsqueeze(1) + attr
                # (num_voxels, num_gaussians, ...) -> (num_voxels * num_gaussians, ...)
                setattr(representation, k, attr.flatten(0, 1))
            ret.append(representation)
        return ret

    def forward(self, x: sp.SparseTensor) -> List[Gaussian]:
        """
        Decode a batch of sparse latents into Gaussian representations.

        Args:
            x: Sparse latent tensor with ``latent_channels`` feature channels.

        Returns:
            One ``Gaussian`` per batch element (see ``to_representation``).
        """
        h = super().forward(x)
        # Cast back to the input dtype (the trunk may run in fp16 when use_fp16=True).
        h = h.type(x.dtype)
        # Parameter-free layer norm over the feature channels before projection.
        h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
        h = self.out_layer(h)
        return self.to_representation(h)
class ElasticSLatGaussianDecoder(SparseTransformerElasticMixin, SLatGaussianDecoder):
    """
    SLatGaussianDecoder variant with elastic memory management.

    Intended for training under a tight VRAM budget. The class body is
    deliberately empty: all behaviour comes from SparseTransformerElasticMixin
    (listed first, so its methods take precedence in the MRO) layered on top
    of SLatGaussianDecoder.
    """
|