# Spaces: Running on Zero
# File size: 6,035 Bytes
# b8db3e2
import os
import json
from typing import *
import numpy as np
import torch
from ..representations import Voxel
from ..renderers import VoxelRenderer
from .components import StandardDatasetBase, ImageConditionedMixin
from .. import models
from ..utils.render_utils import yaw_pitch_r_fov_to_extrinsics_intrinsics
class SparseStructureLatentVisMixin:
    """
    Mixin that decodes sparse structure latents and renders them as images.

    The decoder is loaded lazily by ``decode_latent`` and released again as
    soon as decoding finishes, so it only occupies GPU memory while in use.

    The host class must provide ``self.normalization`` and, when that value is
    truthy, the tensors ``self.mean`` and ``self.std`` (all three are read in
    ``decode_latent``).

    Args:
        pretrained_ss_dec (str): name of the pretrained sparse structure decoder,
            handed to ``models.from_pretrained``
        ss_dec_path (str): path to the sparse structure decoder, if given, will override the pretrained_ss_dec
        ss_dec_ckpt (str): name of the sparse structure decoder checkpoint
    """
    # NOTE(review): the default below ends in '.json' while SparseStructureLatent's
    # default for the same argument does not -- confirm which form
    # models.from_pretrained expects.
    def __init__(
        self,
        *args,
        pretrained_ss_dec: str = 'JeffreyXiang/TRELLIS-image-large/ckpts/ss_dec_conv3d_16l8_fp16.json',
        ss_dec_path: Optional[str] = None,
        ss_dec_ckpt: Optional[str] = None,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.ss_dec = None  # populated on demand by _loading_ss_dec
        self.pretrained_ss_dec = pretrained_ss_dec
        self.ss_dec_path = ss_dec_path
        self.ss_dec_ckpt = ss_dec_ckpt

    def _loading_ss_dec(self):
        """Load the decoder onto the GPU in eval mode; no-op if already loaded."""
        if self.ss_dec is not None:
            return
        if self.ss_dec_path is not None:
            # Local decoder: build the model from config.json, then load
            # ckpts/decoder_{ss_dec_ckpt}.pt. With ss_dec_ckpt left as None
            # this resolves to 'decoder_None.pt'.
            with open(os.path.join(self.ss_dec_path, 'config.json'), 'r') as f:
                cfg = json.load(f)
            decoder_cfg = cfg['models']['decoder']
            decoder = getattr(models, decoder_cfg['name'])(**decoder_cfg['args'])
            ckpt_path = os.path.join(self.ss_dec_path, 'ckpts', f'decoder_{self.ss_dec_ckpt}.pt')
            decoder.load_state_dict(torch.load(ckpt_path, map_location='cpu', weights_only=True))
        else:
            decoder = models.from_pretrained(self.pretrained_ss_dec)
        self.ss_dec = decoder.cuda().eval()

    def _delete_ss_dec(self):
        """Drop the reference to the decoder so it can be garbage-collected."""
        self.ss_dec = None

    @torch.no_grad()
    def decode_latent(self, z, batch_size=4):
        """
        Decode latents with the sparse structure decoder.

        Args:
            z (torch.Tensor): latents, batched along dim 0
            batch_size (int): number of latents passed to the decoder per call

        Returns:
            torch.Tensor: decoder outputs concatenated along dim 0
        """
        self._loading_ss_dec()
        try:
            if self.normalization:
                # Undo the dataset-side (z - mean) / std normalization.
                z = z * self.std.to(z.device) + self.mean.to(z.device)
            decoded = [
                self.ss_dec(z[start:start + batch_size])
                for start in range(0, z.shape[0], batch_size)
            ]
            return torch.cat(decoded, dim=0)
        finally:
            # Release the decoder even when decoding fails, so an error does
            # not leave it resident on the GPU.
            self._delete_ss_dec()

    @torch.no_grad()
    def visualize_sample(self, x_0: Union[torch.Tensor, dict]):
        """
        Render decoded samples as a 2x2 grid of views each.

        Args:
            x_0 (torch.Tensor or dict): latents, or a dict holding them under 'x_0'

        Returns:
            torch.Tensor: one (3, 1024, 1024) image per sample, stacked along dim 0
        """
        x_0 = x_0 if isinstance(x_0, torch.Tensor) else x_0['x_0']
        x_0 = self.decode_latent(x_0.cuda())

        view_res = 512  # side length of a single rendered view
        tile = [2, 2]   # rows, columns of the view grid

        renderer = VoxelRenderer()
        renderer.rendering_options.resolution = view_res
        renderer.rendering_options.ssaa = 4

        # build camera: four yaws a quarter turn apart, shifted by -16 degrees,
        # all at 20 degrees pitch
        yaw_offset = -16 / 180 * np.pi
        yaw = [y + yaw_offset for y in (0, np.pi/2, np.pi, 3*np.pi/2)]
        pitch = [20 / 180 * np.pi for _ in range(4)]
        # NOTE(review): 2 and 30 are presumably the camera radius and field of
        # view, going by the helper's name -- confirm against render_utils.
        exts, ints = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaw, pitch, 2, 30)

        # Build each representation
        x_0 = x_0.cuda()
        resolution = x_0.shape[-1]
        images = []
        for volume in x_0:
            # Positive entries of channel 0 are the voxels to draw.
            coords = torch.nonzero(volume[0] > 0, as_tuple=False)
            # Color every voxel by its grid position scaled by the resolution.
            color = coords / resolution
            rep = Voxel(
                origin=[-0.5, -0.5, -0.5],
                voxel_size=1/resolution,
                coords=coords,
                attrs=color,
                layout={
                    'color': slice(0, 3),
                }
            )
            # Allocate the canvas directly on the GPU rather than on the CPU
            # followed by a copy.
            image = torch.zeros(3, view_res * tile[0], view_res * tile[1], device='cuda')
            for j, (ext, intr) in enumerate(zip(exts, ints)):
                res = renderer.render(rep, ext, intr, colors_overwrite=color)
                row, col = divmod(j, tile[1])
                image[
                    :,
                    view_res * row:view_res * (row + 1),
                    view_res * col:view_res * (col + 1),
                ] = res['color']
            images.append(image)
        return torch.stack(images)
class SparseStructureLatent(SparseStructureLatentVisMixin, StandardDatasetBase):
    """
    Sparse structure latent dataset

    Args:
        roots (str): path to the dataset
        min_aesthetic_score (float): minimum aesthetic score
        normalization (dict): normalization stats, with 'mean' and 'std' entries
        pretrained_ss_dec (str): name of the pretrained sparse structure decoder
        ss_dec_path (str): path to the sparse structure decoder, if given, will override the pretrained_ss_dec
        ss_dec_ckpt (str): name of the sparse structure decoder checkpoint
    """
    def __init__(self,
        roots: str,
        *,
        min_aesthetic_score: float = 5.0,
        normalization: Optional[dict] = None,
        pretrained_ss_dec: str = 'JeffreyXiang/TRELLIS-image-large/ckpts/ss_dec_conv3d_16l8_fp16',
        ss_dec_path: Optional[str] = None,
        ss_dec_ckpt: Optional[str] = None,
    ):
        # Assigned before super().__init__(); presumably the base constructor
        # already calls filter_metadata, which reads min_aesthetic_score --
        # TODO confirm against StandardDatasetBase.
        self.min_aesthetic_score = min_aesthetic_score
        self.normalization = normalization
        self.value_range = (0, 1)

        super().__init__(
            roots,
            pretrained_ss_dec=pretrained_ss_dec,
            ss_dec_path=ss_dec_path,
            ss_dec_ckpt=ss_dec_ckpt,
        )

        if self.normalization is not None:
            # Reshape the stats to (C, 1, 1, 1) so they broadcast over the
            # three trailing dimensions of a latent.
            self.mean = torch.tensor(self.normalization['mean']).reshape(-1, 1, 1, 1)
            self.std = torch.tensor(self.normalization['std']).reshape(-1, 1, 1, 1)

    def filter_metadata(self, metadata):
        """
        Keep only instances with an encoded latent and a sufficient aesthetic score.

        Args:
            metadata: table with 'ss_latent_encoded' and 'aesthetic_score'
                columns (presumably a pandas DataFrame -- confirm with caller)

        Returns:
            tuple: the filtered metadata and a dict with the number of rows
                remaining after each filter
        """
        stats = {}
        # `== True` is an element-wise comparison that builds the row mask.
        metadata = metadata[metadata['ss_latent_encoded'] == True]
        stats['With latent'] = len(metadata)
        metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score]
        stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata)
        return metadata, stats

    def get_instance(self, root, instance):
        """
        Load the latent of one instance.

        Args:
            root: mapping whose 'ss_latent' entry is the directory holding the latents
            instance: instance name; the latent is read from '{instance}.npz'

        Returns:
            dict: {'x_0': float tensor}, normalized when normalization stats are set
        """
        # np.load returns a lazily-read archive for .npz files; the context
        # manager closes the underlying file once 'z' has been copied out.
        with np.load(os.path.join(root['ss_latent'], f'{instance}.npz')) as latent:
            z = torch.tensor(latent['z']).float()
        if self.normalization is not None:
            z = (z - self.mean) / self.std

        pack = {
            'x_0': z,
        }
        return pack
class ImageConditionedSparseStructureLatent(ImageConditionedMixin, SparseStructureLatent):
    """
    Image-conditioned sparse structure latent dataset.

    Combines ImageConditionedMixin with SparseStructureLatent through
    inheritance; no attributes or methods are defined here.
    """
|