# File size: 4,273 Bytes
from typing import *
from easydict import EasyDict as edict
import torch
import torch.nn as nn
import torch.nn.functional as F
from safetensors.torch import load_file
from ...modules import sparse as sp
from ...utils.random_utils import hammersley_sequence
from .base import SparseTransformerRegisterSelfBase
from ...representations import Gaussian_view as Gaussian
from ..sparse_elastic_mixin import SparseTransformerElasticMixin
from .hdri_encoder import Hdri_Encoder
from .nerf_encoding import NeRFEncoding
class SLatGaussianDecoder(SparseTransformerRegisterSelfBase):
    """
    Sparse transformer decoder that runs with a set of learnable register tokens.

    The constructor forwards the transformer hyper-parameters to
    ``SparseTransformerRegisterSelfBase``, creates the register tokens,
    initializes the weights and, optionally, merges a pretrained checkpoint
    into the model before converting it to fp16.

    Args:
        resolution: Stored on ``self.resolution``; not otherwise used here.
        model_channels: Transformer width; also the register-token width.
        latent_channels: Passed to the base class as ``in_channels``.
        cond_channels: Forwarded to the base class.
        num_blocks: Forwarded to the base class.
        num_register_tokens: Number of learnable register tokens.
        pretrained_decoder_path: Optional checkpoint to merge into the model.
            Paths ending in ``.safetensors`` are read with ``load_file``;
            anything else is read with ``torch.load(..., weights_only=True)``.
        num_heads, num_head_channels, mlp_ratio, attn_mode, window_size,
        pe_mode, use_checkpoint, qk_rms_norm, qk_rms_norm_cross:
            Forwarded unchanged to the base class.
        use_fp16: Forwarded to the base class; when true,
            ``convert_to_fp16()`` is also called at the end of construction.
        representation_config: Accepted for interface compatibility; this
            class does not read it.
    """

    # State-dict key of the input projection that may need column tiling.
    _INPUT_WEIGHT_KEY = "input_layer.weight"
    # Column width of one tiled copy and the number of copies written.
    # NOTE(review): presumably the pretrained decoder has 8 input channels and
    # this model consumes 3 such groups (24 channels) -- confirm with configs.
    _PRETRAINED_IN_CHANNELS = 8
    _INPUT_REPLICAS = 3

    def __init__(
        self,
        resolution: int,
        model_channels: int,
        latent_channels: int,
        cond_channels: int,
        num_blocks: int,
        num_register_tokens: int = 16,
        pretrained_decoder_path: Optional[str] = None,
        num_heads: Optional[int] = None,
        num_head_channels: Optional[int] = 64,
        mlp_ratio: float = 4,
        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
        window_size: int = 8,
        pe_mode: Literal["ape", "rope"] = "ape",
        use_fp16: bool = False,
        use_checkpoint: bool = False,
        qk_rms_norm: bool = False,
        qk_rms_norm_cross: bool = False,
        representation_config: Optional[dict] = None,
    ):
        super().__init__(
            in_channels=latent_channels,
            model_channels=model_channels,
            cond_channels=cond_channels,
            num_blocks=num_blocks,
            num_heads=num_heads,
            num_head_channels=num_head_channels,
            mlp_ratio=mlp_ratio,
            attn_mode=attn_mode,
            window_size=window_size,
            pe_mode=pe_mode,
            use_fp16=use_fp16,
            use_checkpoint=use_checkpoint,
            qk_rms_norm=qk_rms_norm,
            qk_rms_norm_cross=qk_rms_norm_cross,
        )
        self.resolution = resolution
        self.num_register_tokens = num_register_tokens
        self.reg_tokens = nn.Parameter(torch.randn(1, num_register_tokens, model_channels))

        # Initialize first so that any pretrained weights overwrite the init.
        self.initialize_weights()
        if pretrained_decoder_path is not None:
            self._load_pretrained_decoder(pretrained_decoder_path)

        if use_fp16:
            self.convert_to_fp16()

    def _load_pretrained_decoder(self, pretrained_decoder_path: str) -> None:
        """
        Merge the checkpoint at ``pretrained_decoder_path`` into this model.

        Parameters missing from the checkpoint (e.g. ``reg_tokens``) keep
        their freshly initialized values because the merge starts from the
        model's own state dict.

        Raises:
            RuntimeError: From ``load_state_dict(strict=True)`` when the merged
                state dict does not match the model (wrong shapes, or unknown
                keys coming from a non-safetensors checkpoint).
        """
        from_safetensors = pretrained_decoder_path.endswith('.safetensors')
        if from_safetensors:
            checkpoint = load_file(pretrained_decoder_path)
        else:
            checkpoint = torch.load(pretrained_decoder_path, map_location='cpu', weights_only=True)

        merged = self.state_dict()
        for key, value in checkpoint.items():
            # Only safetensors checkpoints tolerate keys the model lacks;
            # for other formats such keys reach the strict load below and
            # make it fail.
            # NOTE(review): this asymmetry is inherited -- confirm it is intended.
            if from_safetensors and key not in merged:
                continue
            if key == self._INPUT_WEIGHT_KEY and merged[key].shape != value.shape:
                # Narrower pretrained input projection: tile it across the
                # wider input of this model.
                self._tile_input_weight(merged[key], value)
            else:
                # Also covers an input projection that already has the model's
                # shape (e.g. a checkpoint saved from this class), which must
                # be copied as-is rather than tiled.
                merged[key] = value
        self.load_state_dict(merged, strict=True)
        print(f"Loaded pretrained decoder from {pretrained_decoder_path}")

    @classmethod
    def _tile_input_weight(cls, target: torch.Tensor, source: torch.Tensor) -> None:
        """Write ``source`` in place into consecutive column groups of ``target``."""
        width = cls._PRETRAINED_IN_CHANNELS
        for group in range(cls._INPUT_REPLICAS):
            target[:, group * width:(group + 1) * width] = source

    def initialize_weights(self) -> None:
        """Initialize weights by delegating to the base class."""
        super().initialize_weights()

    def forward(self, x: sp.SparseTensor) -> Tuple[sp.SparseTensor, torch.Tensor]:
        """
        Run the base transformer on ``x`` together with the register tokens.

        Args:
            x: Sparse input; ``x.shape[0]`` sets how many copies of the
                register tokens are expanded.

        Returns:
            ``(h, reg)`` as produced by the base class ``forward``, both cast
            to ``x.dtype``.
        """
        # expand() creates a view, so all batch entries share the parameters.
        reg_feats = self.reg_tokens.expand(x.shape[0], -1, -1)
        h, reg = super().forward(x, reg=reg_feats)
        return h.type(x.dtype), reg.type(x.dtype)
class ElasticSLatGaussianDecoder(SparseTransformerElasticMixin, SLatGaussianDecoder):
    """
    SLatGaussianDecoder combined with SparseTransformerElasticMixin.

    Adds no members of its own; the mixin supplies the elastic memory
    management, which is meant for training when VRAM is limited.
    """
|