# Extend3D — extend3d.py
# Origin: Hugging Face Space upload by Seungwoo-Yoon
# ("initial commit for HF space", commit a68e3ed)
import os
import json
import numpy as np
from PIL import Image
from typing import List
from tqdm import tqdm, trange
os.environ['SPCONV_ALGO'] = 'native'
import torch
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from torchmetrics.image.ssim import StructuralSimilarityIndexMeasure
from trellis.pipelines.base import Pipeline
from trellis.pipelines import TrellisImageTo3DPipeline
from trellis.models import SparseStructureFlowModel, SparseStructureEncoder, SparseStructureDecoder
from trellis.modules.sparse.basic import sparse_cat, sparse_unbind, SparseTensor
from trellis.utils import render_utils
from trellis.representations.mesh import MeshExtractResult
from trellis.representations.mesh.utils_cube import sparse_cube2verts
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from utils import *
class Extend3D(Pipeline):
    """Scene-extension wrapper around the Trellis image-to-3D pipeline.

    Composes a pretrained ``TrellisImageTo3DPipeline`` with a replacement
    sparse-structure encoder checkpoint, LPIPS/SSIM metrics used as frozen
    loss terms during sampling, and the SLAT normalization constants.
    """

    # -----------------------------------------------------------------------
    # Construction
    # -----------------------------------------------------------------------
    def __init__(self, ckpt_path: str, device: str = 'cpu'):
        """Build the pipeline from a Hugging Face checkpoint repository.

        Args:
            ckpt_path: HF repo id, passed both to
                ``TrellisImageTo3DPipeline.from_pretrained`` and to
                ``hf_hub_download`` for the encoder config/weights files.
            device: torch device string all modules are moved to.
        """
        super().__init__()
        # Load the base Trellis pipeline
        self.pipeline = TrellisImageTo3DPipeline.from_pretrained(ckpt_path)
        self.pipeline.to(device)
        # Alias, not a copy: the encoder swap below also mutates pipeline.models.
        self.models = self.pipeline.models
        # Replace the sparse-structure encoder with a higher-capacity checkpoint
        config_path = hf_hub_download(repo_id=ckpt_path,
                                      filename='ckpts/ss_enc_conv3d_16l8_fp16.json')
        model_path = hf_hub_download(repo_id=ckpt_path,
                                     filename='ckpts/ss_enc_conv3d_16l8_fp16.safetensors')
        with open(config_path, 'r') as f:
            model_config = json.load(f)
        state_dict = load_file(model_path)
        encoder = SparseStructureEncoder(**model_config['args'])
        encoder.load_state_dict(state_dict)
        self.models['sparse_structure_encoder'] = encoder.to(device)
        # Perceptual metrics used for SLAT optimization loss (frozen, no gradients needed)
        self.lpips = LearnedPerceptualImagePatchSimilarity(normalize=True, net_type='squeeze').to(device)
        self.ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)
        self.lpips.requires_grad_(False)
        self.ssim.requires_grad_(False)
        # SLAT normalization constants (frozen; gradients must not flow through them)
        self.std = torch.tensor(self.pipeline.slat_normalization['std'])[None].to(device)
        self.mean = torch.tensor(self.pipeline.slat_normalization['mean'])[None].to(device)
        self.std.requires_grad_(False)
        self.mean.requires_grad_(False)
# -----------------------------------------------------------------------
# Device management
# -----------------------------------------------------------------------
def to(self, device) -> "Extend3D":
self.pipeline.to(device)
self.models['sparse_structure_encoder'] = self.models['sparse_structure_encoder'].to(device)
self.lpips = self.lpips.to(device)
self.ssim = self.ssim.to(device)
self.std = self.std.to(device)
self.mean = self.mean.to(device)
return self
def cuda(self) -> "Extend3D":
return self.to(torch.device('cuda'))
def cpu(self) -> "Extend3D":
return self.to(torch.device('cpu'))
@staticmethod
def from_pretrained(ckpt_path: str, device: str = 'cpu') -> "Extend3D":
return Extend3D(ckpt_path, device=device)
# -----------------------------------------------------------------------
# Preprocessing
# -----------------------------------------------------------------------
@staticmethod
def preprocess(image: Image.Image) -> Image.Image:
return image.resize((1024, 1024), Image.Resampling.LANCZOS)
# -----------------------------------------------------------------------
# Conditioning
# -----------------------------------------------------------------------
@torch.no_grad()
def get_cond(
self,
image: Image.Image,
pointmap_info: PointmapInfo = None,
width: int = 2,
length: int = 2,
div: int = 2,
) -> List[List[dict]]:
"""Compute per-patch image conditioning for the flow model."""
if pointmap_info is None:
pointmap_info = PointmapInfo(image, device=self.device)
patches = pointmap_info.divide_image(width, length, div)
return [
[self.pipeline.get_cond([self.preprocess(patch)]) for patch in row]
for row in patches
]
# -----------------------------------------------------------------------
# Stage 1: Sparse structure sampling
# -----------------------------------------------------------------------
    def sample_sparse_structure(
        self,
        image: Image.Image,
        pointmap_info: PointmapInfo = None,
        optim: bool = True,
        width: int = 2,
        length: int = 2,
        div: int = 2,
        iterations: int = 3,
        steps: int = 25,
        rescale_t: float = 3.0,
        t_noise: float = 0.6,
        t_start: float = 0.8,
        cfg_strength: float = 7.5,
        alpha: float = 5.0,
        batch_size: int = 1,
        progress_callback=None,
    ) -> torch.Tensor:
        """
        Sample occupied voxel coordinates via iterative flow-matching.

        Each outer iteration re-noises the encoded point-cloud voxel grid to
        ``t_noise`` and denoises it step by step. At every step, velocities
        from overlapping patches are averaged and blended with velocities
        from dilated (strided, global) sub-volumes; optionally the blended
        velocity is refined with Adam against a point-cloud occupancy loss
        before the Euler update.

        Args:
            image: input RGB image (conditioning source).
            pointmap_info: precomputed pointmap; built from *image* when None.
            optim: enable the per-step Adam refinement of the velocity.
            width, length: tile counts of the extended scene along y and x.
            div: patch-overlap divisor forwarded to ``get_views``.
            iterations: number of re-noise/denoise rounds.
            steps, rescale_t, t_start: flow-matching schedule parameters.
            t_noise: noise level applied at the start of every iteration.
            cfg_strength: classifier-free-guidance strength.
            alpha: exponent of the cosine local/global blending factor.
            batch_size: number of patches per flow-model call.
            progress_callback: optional ``fn(fraction, message)``; when given,
                the tqdm progress bars are suppressed.

        Returns:
            coords: int32 tensor of shape [N, 4] (batch, y, x, z).
        """
        if pointmap_info is None:
            pointmap_info = PointmapInfo(image, device=self.device)
        flow_model: SparseStructureFlowModel = self.models['sparse_structure_flow_model']
        encoder: SparseStructureEncoder = self.models['sparse_structure_encoder']
        decoder: SparseStructureDecoder = self.models['sparse_structure_decoder']
        sampler = self.pipeline.sparse_structure_sampler
        cfg_interval = self.pipeline.sparse_structure_sampler_params['cfg_interval']
        # Freeze the decoder so Adam below only updates the velocity leaf.
        for p in decoder.parameters():
            p.requires_grad_(False)
        sigma_min = sampler.sigma_min
        reso = flow_model.resolution
        # Build point cloud from the pointmap info
        pc = torch.tensor(pointmap_info.point_cloud(), dtype=torch.float32)
        # Rescale z to match the anisotropic (length, width, 1) tiled grid.
        pc[:, 2] *= max(width, length)
        # Encode initial voxel from the point cloud
        voxel = pointcloud_to_voxel(pc, (4 * reso * length, 4 * reso * width, 4 * reso))
        # Swap axes 2 and 3 to the encoder's expected layout (undone again
        # in the loss below via the inverse permute).
        voxel = voxel.permute(0, 1, 3, 2, 4).float().to(self.device)
        encoded_voxel = encoder(voxel)
        pc = pc.to(self.device)
        _, t_pairs = schedule(steps, rescale_t, start=t_start)
        views = get_views(width, length, reso, div)
        # Latent tensor and accumulation buffers
        latent = torch.randn(1, flow_model.in_channels, reso * width, reso * length, reso,
                             device=self.device)
        count = torch.zeros_like(latent)
        value = torch.zeros_like(latent)
        # Conditioning: one global embedding plus one per overlapping patch.
        global_cond = self.get_cond(image, pointmap_info, 1, 1, 1)[0][0]
        cond = self.get_cond(image, pointmap_info, width, length, div)
        total_steps = iterations * len(t_pairs)
        global_step = 0
        iter_range = trange(iterations, position=0) if progress_callback is None else range(iterations)
        for it in iter_range:
            # Noise the latent to t_noise at the start of each iteration
            # (this overwrites the randn initialization above).
            latent = diffuse(encoded_voxel, torch.tensor(t_noise, device=self.device), sigma_min)
            latent = latent.detach()
            step_iter = (tqdm(t_pairs, desc="Sparse Structure Sampling", position=1)
                         if progress_callback is None else t_pairs)
            for t, t_prev in step_iter:
                # Blend weight: c is ~1 at t=1 (global prediction dominates
                # early) and decays to 0 as t -> 0 (local patches dominate).
                cosine_factor = 0.5 * (1 + torch.cos(torch.pi * (1 - torch.tensor(t))))
                c = cosine_factor ** alpha
                with torch.no_grad():
                    # --- 1. Overlapping patch-wise flow ---
                    count.zero_()
                    value.zero_()
                    local_latents, patch_conds, patch_neg_conds, patch_views = [], [], [], []
                    for view in views:
                        i, j, y0, y1, x0, x1 = view
                        patch_views.append(view)
                        local_latents.append(latent[:, :, y0:y1, x0:x1, :].contiguous())
                        patch_cond = cond[i][j]
                        patch_conds.append(patch_cond['cond'])
                        patch_neg_conds.append(patch_cond['neg_cond'])
                    for start in range(0, len(local_latents), batch_size):
                        end = min(start + batch_size, len(local_latents))
                        out = sampler.sample_once(
                            flow_model,
                            torch.cat(local_latents[start:end], dim=0),
                            t, t_prev,
                            cond=torch.cat(patch_conds[start:end], dim=0),
                            neg_cond=torch.cat(patch_neg_conds[start:end], dim=0),
                            cfg_strength=cfg_strength,
                            cfg_interval=cfg_interval,
                        )
                        # Scatter patch velocities back; overlaps averaged below.
                        for view, pred_v in zip(patch_views[start:end], out.pred_v):
                            _, _, y0, y1, x0, x1 = view
                            count[:, :, y0:y1, x0:x1, :] += 1
                            value[:, :, y0:y1, x0:x1, :] += pred_v
                    local_pred_v = torch.where(count > 0, value / count, latent)
                    # --- 2. Dilated sampling (global structure) ---
                    count.zero_()
                    value.zero_()
                    dilated_samples = dilated_sampling(reso, width, length)
                    dilated_latents = []
                    dilated_conds = []
                    dilated_neg_conds = []
                    for sample in dilated_samples:
                        # Gather strided (y, x) sites into a reso^3 sub-volume.
                        sample_latent = (latent[:, :, sample[:, 0], sample[:, 1], :]
                                         .view(1, flow_model.in_channels, reso, reso, reso))
                        dilated_latents.append(sample_latent)
                        dilated_conds.append(global_cond['cond'])
                        dilated_neg_conds.append(global_cond['neg_cond'])
                    for start in range(0, len(dilated_latents), batch_size):
                        end = min(start + batch_size, len(dilated_latents))
                        out = sampler.sample_once(
                            flow_model,
                            torch.cat(dilated_latents[start:end], dim=0),
                            t, t_prev,
                            cond=torch.cat(dilated_conds[start:end], dim=0),
                            neg_cond=torch.cat(dilated_neg_conds[start:end], dim=0),
                            cfg_strength=cfg_strength,
                            cfg_interval=cfg_interval,
                        )
                        for sample, pred_v in zip(dilated_samples[start:end], out.pred_v):
                            count[:, :, sample[:, 0], sample[:, 1], :] += 1
                            value[:, :, sample[:, 0], sample[:, 1], :] += pred_v.view(
                                1, flow_model.in_channels, reso * reso, reso
                            )
                    global_pred_v = torch.where(count > 0, value / count, latent)
                    # Blend local and global velocity predictions
                    v = local_pred_v * (1 - c) + global_pred_v * c
                v = v.detach()
                # Enable grad so that Adam can optimize v as a leaf variable
                v.requires_grad_()
                v.retain_grad()
                optimizer = torch.optim.Adam([v], lr=0.1)
                if optim and t < 0.7:
                    for _ in range(20):
                        optimizer.zero_grad()
                        # x0-prediction implied by velocity v at time t.
                        pred_latent = (1 - sigma_min) * latent - (sigma_min + (1 - sigma_min) * t) * v
                        decoded_latent = decoder(pred_latent)
                        # Undo the axis swap applied before encoding so the
                        # decoded grid matches the point cloud's axis order.
                        loss = sparse_structure_loss(pc, decoded_latent.permute(0, 1, 3, 2, 4))
                        loss.backward()
                        optimizer.step()
                # Euler step
                latent = (latent - (t - t_prev) * v).detach()
                if progress_callback is not None:
                    global_step += 1
                    progress_callback(
                        global_step / total_steps,
                        f"Sparse Structure: iter {it + 1}/{iterations}, step {global_step}/{total_steps}",
                    )
            # Re-encode the decoded voxel for the next iteration
            voxel = (decoder(latent) > 0).float()
            encoded_voxel = encoder(voxel)
        # argwhere on the 5D decoded grid yields (batch, channel, y, x, z);
        # column 1 (channel) is dropped to produce [N, 4] coordinates.
        coords = torch.argwhere(decoder(latent) > 0)[:, [0, 2, 3, 4]].int()
        return coords
# -----------------------------------------------------------------------
# Stage 2: Structured latent (SLAT) sampling
# -----------------------------------------------------------------------
    def sample_slat(
        self,
        image: Image.Image,
        coords: torch.Tensor,
        pointmap_info: PointmapInfo = None,
        optim: bool = True,
        width: int = 2,
        length: int = 2,
        div: int = 2,
        steps: int = 25,
        rescale_t: float = 3.0,
        cfg_strength: float = 3.0,
        batch_size: int = 1,
        progress_callback=None,
    ) -> SparseTensor:
        """
        Sample per-voxel latent features (SLAT) via flow-matching.

        Velocities are predicted per overlapping patch and averaged over the
        voxels shared by several patches. Optionally each velocity estimate
        is refined with Adam against an (LPIPS - SSIM) loss between a
        512x512 Gaussian-splat render of the current prediction and the
        resized input image.

        Args:
            image: input RGB image (conditioning and photometric target).
            coords: [N, 4] int voxel coordinates from stage 1.
            pointmap_info: precomputed pointmap; built from *image* when None.
            optim: enable the per-step Adam refinement.
            width, length, div: patch grid layout, as in stage 1.
            steps, rescale_t: flow-matching schedule parameters.
            cfg_strength: classifier-free-guidance strength.
            batch_size: patches per flow-model call.
            progress_callback: optional ``fn(fraction, message)``; when given,
                the tqdm progress bar is suppressed.

        Returns:
            slat: SparseTensor with denormalized latent features.
        """
        if pointmap_info is None:
            pointmap_info = PointmapInfo(image, device=self.device)
        # Prepare reference image tensor for perceptual optimization loss
        resized_image = image.resize((512, 512))
        # HWC uint8 -> CHW float in [0, 1].
        tensor_image = (torch.from_numpy(np.array(resized_image))
                        .permute(2, 0, 1).float() / 255.0).to(self.device)
        intrinsic = torch.tensor(pointmap_info.camera_intrinsic(), dtype=torch.float32).to(self.device)
        extrinsic = torch.tensor(pointmap_info.camera_extrinsic(), dtype=torch.float32).to(self.device)
        flow_model = self.models['slat_flow_model']
        sampler = self.pipeline.slat_sampler
        cfg_interval = self.pipeline.slat_sampler_params['cfg_interval']
        cond = self.get_cond(image, pointmap_info, width, length, div)
        sigma_min = sampler.sigma_min
        reso = flow_model.resolution
        latent_feats = torch.randn(coords.shape[0], flow_model.in_channels, device=self.device)
        # Pre-compute where each voxel coordinate falls in the overlapping patch grid
        views = get_views(width, length, reso, div)
        valid_views = []
        patch_indices = []
        for i, j, y0, y1, x0, x1 in views:
            idx = torch.where(
                (coords[:, 1] >= y0) & (coords[:, 1] < y1) &
                (coords[:, 2] >= x0) & (coords[:, 2] < x1)
            )[0]
            # Patches containing no occupied voxel are skipped entirely.
            if len(idx) > 0:
                valid_views.append((i, j, y0, y1, x0, x1))
                patch_indices.append(idx)
        count = torch.zeros(coords.shape[0], flow_model.in_channels, device=self.device)
        value = torch.zeros(coords.shape[0], flow_model.in_channels, device=self.device)
        _, t_pairs = schedule(steps, rescale_t)
        total_steps = len(t_pairs)
        step_iter = (tqdm(t_pairs, desc="Structured Latent Sampling")
                     if progress_callback is None else t_pairs)
        for slat_step, (t, t_prev) in enumerate(step_iter, start=1):
            with torch.no_grad():
                count.zero_()
                value.zero_()
                patch_latents = []
                patch_conds = []
                for view, patch_index in zip(valid_views, patch_indices):
                    i, j, y0, y1, x0, x1 = view
                    patch_conds.append(cond[i][j])
                    patch_coords_local = coords[patch_index].clone()
                    # Shift into patch-local coordinates.
                    patch_coords_local[:, 1] -= y0
                    patch_coords_local[:, 2] -= x0
                    patch_latents.append(SparseTensor(
                        feats=latent_feats[patch_index].contiguous(),
                        coords=patch_coords_local,
                    ))
                for start in range(0, len(patch_latents), batch_size):
                    end = min(start + batch_size, len(patch_latents))
                    conds_chunk = patch_conds[start:end]
                    # Batch the conditioning dicts key-wise across the chunk.
                    batched_cond = {
                        k: torch.cat([d[k] for d in conds_chunk], dim=0)
                        for k in conds_chunk[0].keys()
                    }
                    outs = sampler.sample_once(
                        flow_model,
                        sparse_cat(patch_latents[start:end]),
                        t, t_prev,
                        cfg_strength=cfg_strength,
                        cfg_interval=cfg_interval,
                        **batched_cond,
                    )
                    # Accumulate per-voxel velocities; overlaps are averaged below.
                    for out, pidx in zip(sparse_unbind(outs.pred_v, dim=0), patch_indices[start:end]):
                        count[pidx, :] += 1
                        value[pidx, :] += out.feats
                v_feats = torch.where(count > 0, value / count, latent_feats).detach()
            # Enable grad for leaf-variable optimization
            v_feats.requires_grad_()
            optimizer = torch.optim.Adam([v_feats], lr=0.3)
            if optim and t < 0.8:
                for _ in range(20):
                    optimizer.zero_grad()
                    # x0-prediction from the current velocity estimate.
                    pred_feats = (1 - sigma_min) * latent_feats - (sigma_min + (1 - sigma_min) * t) * v_feats
                    pred_slat = SparseTensor(feats=pred_feats, coords=coords) * self.std + self.mean
                    rendered = render_utils.render_frames_torch(
                        self.decode_slat(pred_slat, width, length, formats=['gaussian'])['gaussian'][0],
                        [extrinsic], [intrinsic],
                        {'resolution': 512, 'bg_color': (0, 0, 0)},
                        verbose=False,
                    )['color'][0].permute(2, 1, 0)
                    # Minimize LPIPS while maximizing SSIM against the input.
                    loss = (self.lpips(rendered.unsqueeze(0), tensor_image.unsqueeze(0))
                            - self.ssim(rendered.unsqueeze(0), tensor_image.unsqueeze(0)))
                    loss.backward()
                    optimizer.step()
            # Euler step; detach to free the computation graph
            latent_feats = (latent_feats - (t - t_prev) * v_feats).detach()
            if progress_callback is not None:
                progress_callback(slat_step / total_steps,
                                  f"SLAT Sampling: step {slat_step}/{total_steps}")
        slat = SparseTensor(feats=latent_feats, coords=coords)
        # Denormalize with the frozen stats before returning.
        return slat * self.std + self.mean
# -----------------------------------------------------------------------
# Stage 3: Decode SLAT → Gaussians and/or mesh
# -----------------------------------------------------------------------
    def decode_slat(
        self,
        slat: SparseTensor,
        width: int,
        length: int,
        formats: list[str] = ['gaussian', 'mesh'],
    ) -> dict:
        """Decode a structured latent into Gaussian splats and/or a triangle mesh.

        Args:
            slat: denormalized structured latent (feats + [N, 4] coords).
            width, length: tile counts of the extended scene along y and x.
            formats: subset of {'gaussian', 'mesh'} to decode.
                NOTE(review): mutable default argument — benign here because
                it is only read, never mutated, but a tuple would be safer.

        Returns:
            dict with key 'mesh' (list of MeshExtractResult, possibly empty)
            and/or 'gaussian' (list of merged Gaussian sets), per *formats*.
        """
        ret = {}
        feats = slat.feats
        coords = slat.coords
        reso = self.models['slat_flow_model'].resolution
        scale = max(width, length)
        # -------------------------------------------------------------------
        # Mesh decoding
        # -------------------------------------------------------------------
        if 'mesh' in formats:
            mesh_decoder = self.pipeline.models['slat_decoder_mesh']
            sf2m = mesh_decoder.mesh_extractor  # SparseFeatures2Mesh
            # Global high-res grid dimensions (4x upsampling from SLAT resolution)
            up_res = mesh_decoder.resolution * 4
            res_y, res_x, res_z = width * up_res, length * up_res, up_res
            # Accumulate high-res sparse features across overlapping patches with cosine blending
            C = sf2m.feats_channels
            global_sum = torch.zeros(res_y, res_x, res_z, C, device=self.device)
            global_count = torch.zeros(res_y, res_x, res_z, 1, device=self.device)
            for _, _, y_start, y_end, x_start, x_end in get_views(width, length, reso, 4):
                patch_index = torch.where(
                    (coords[:, 1] >= y_start) & (coords[:, 1] < y_end) &
                    (coords[:, 2] >= x_start) & (coords[:, 2] < x_end)
                )[0]
                if len(patch_index) == 0:
                    continue
                patch_coords = coords[patch_index].clone()
                # Shift into patch-local coordinates before decoding.
                patch_coords[:, 1] -= y_start
                patch_coords[:, 2] -= x_start
                patch_latent = SparseTensor(
                    feats=feats[patch_index].contiguous(),
                    coords=patch_coords,
                )
                patch_hr = mesh_decoder.forward_features(patch_latent)
                # Cosine spatial weight: 1 at patch center, 0 at edges
                hr_coords = patch_hr.coords[:, 1:].clone()  # [N, 3]
                patch_size = float(4 * reso)
                cos_w = (torch.cos(torch.pi * (hr_coords[:, 0].float() / patch_size - 0.5))
                         * torch.cos(torch.pi * (hr_coords[:, 1].float() / patch_size - 0.5))
                         ).unsqueeze(1)  # [N, 1]
                # Shift to global coordinates
                hr_coords[:, 0] = (hr_coords[:, 0] + 4 * y_start).clamp(0, res_y - 1)
                hr_coords[:, 1] = (hr_coords[:, 1] + 4 * x_start).clamp(0, res_x - 1)
                hr_coords[:, 2] = hr_coords[:, 2].clamp(0, res_z - 1)
                gy, gx, gz = hr_coords[:, 0], hr_coords[:, 1], hr_coords[:, 2]
                # NOTE(review): plain indexed += does not accumulate duplicate
                # (gy, gx, gz) entries within one patch; assumes the decoder
                # emits unique per-patch coordinates — TODO confirm.
                global_sum[gy, gx, gz] += patch_hr.feats * cos_w
                global_count[gy, gx, gz] += cos_w
            # Average overlapping regions
            occupied = global_count[..., 0] > 0
            global_sum[occupied] /= global_count[occupied]
            if occupied.any():
                occ_coords = torch.argwhere(occupied)
                occ_feats = global_sum[occ_coords[:, 0], occ_coords[:, 1], occ_coords[:, 2]]
                # Extract per-cube SDF, deformation, color, and FlexiCubes weights
                sdf = sf2m.get_layout(occ_feats, 'sdf') + sf2m.sdf_bias  # [N, 8, 1]
                deform = sf2m.get_layout(occ_feats, 'deform')  # [N, 8, 3]
                color = sf2m.get_layout(occ_feats, 'color')  # [N, 8, 6] or None
                weights = sf2m.get_layout(occ_feats, 'weights')  # [N, 21]
                v_attrs_cat = (torch.cat([sdf, deform, color], dim=-1)
                               if sf2m.use_color else torch.cat([sdf, deform], dim=-1))
                # Merge cube corners into unique vertices
                v_pos, v_attrs, _ = sparse_cube2verts(occ_coords, v_attrs_cat, training=False)
                # Build flat dense vertex attribute array for the global grid
                res_vy, res_vx, res_vz = res_y + 1, res_x + 1, res_z + 1
                v_attrs_d = torch.zeros(res_vy * res_vx * res_vz, v_attrs.shape[-1], device=self.device)
                v_attrs_d[:, 0] = 1.0  # SDF default: outside surface
                vert_ids = v_pos[:, 0] * res_vx * res_vz + v_pos[:, 1] * res_vz + v_pos[:, 2]
                v_attrs_d[vert_ids] = v_attrs
                sdf_d = v_attrs_d[:, 0]
                deform_d = v_attrs_d[:, 1:4]
                colors_d = v_attrs_d[:, 4:] if sf2m.use_color else None
                # Build flat dense cube weight array
                weights_d = torch.zeros(res_y * res_x * res_z, weights.shape[-1], device=self.device)
                cube_ids = occ_coords[:, 0] * res_x * res_z + occ_coords[:, 1] * res_z + occ_coords[:, 2]
                weights_d[cube_ids] = weights
                # Regular vertex position grid [V, 3], normalized to world space
                ay, ax, az = (torch.arange(r, device=self.device, dtype=torch.float)
                              for r in (res_vy, res_vx, res_vz))
                gy, gx, gz = torch.meshgrid(ay, ax, az, indexing='ij')
                reg_v = torch.stack([gy.flatten(), gx.flatten(), gz.flatten()], dim=1)
                # Normalize to Gaussian world coordinate convention:
                #   y, x : [-0.5, 0.5] (centered)
                #   z    : [0, 1/scale] (not centered)
                norm_val = scale * up_res
                norm_t = torch.tensor([norm_val, norm_val, norm_val], device=self.device, dtype=torch.float)
                offset_t = torch.tensor([0.5, 0.5, 0.0], device=self.device, dtype=torch.float)
                x_nx3 = reg_v / norm_t - offset_t + (1 - 1e-8) / (norm_t * 2) * torch.tanh(deform_d)
                # Global cube -> 8 corner vertex index table [C_total, 8]
                cy, cx, cz = (torch.arange(r, device=self.device) for r in (res_y, res_x, res_z))
                gy, gx, gz = torch.meshgrid(cy, cx, cz, indexing='ij')
                cc = torch.tensor(
                    [[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0],
                     [0, 0, 1], [1, 0, 1], [0, 1, 1], [1, 1, 1]],
                    dtype=torch.long, device=self.device,
                )
                reg_c = ((gy.flatten().unsqueeze(1) + cc[:, 0]) * res_vx * res_vz
                         + (gx.flatten().unsqueeze(1) + cc[:, 1]) * res_vz
                         + (gz.flatten().unsqueeze(1) + cc[:, 2]))  # [C, 8]
                # Single FlexiCubes call on the full global SDF
                vertices, faces, _, colors = sf2m.mesh_extractor(
                    voxelgrid_vertices=x_nx3,
                    scalar_field=sdf_d,
                    cube_idx=reg_c,
                    resolution=[res_y, res_x, res_z],
                    beta=weights_d[:, :12],
                    alpha=weights_d[:, 12:20],
                    gamma_f=weights_d[:, 20],
                    voxelgrid_colors=colors_d,
                    training=False,
                )
                ret['mesh'] = [MeshExtractResult(
                    vertices=vertices,
                    faces=faces,
                    vertex_attrs=colors,
                    res=max(res_y, res_x, res_z),
                )]
            else:
                ret['mesh'] = []
        # -------------------------------------------------------------------
        # Gaussian decoding
        # -------------------------------------------------------------------
        if 'gaussian' in formats:
            gs_decoder = self.pipeline.models['slat_decoder_gs']
            # Decode each patch and collect Gaussian lists per batch element
            all_patch_lists: list | None = None
            for i in range(width):
                for j in range(length):
                    # Non-overlapping reso-sized tiles (unlike the mesh path).
                    y0, y1 = i * reso, (i + 1) * reso
                    x0, x1 = j * reso, (j + 1) * reso
                    patch_index = torch.where(
                        (coords[:, 1] >= y0) & (coords[:, 1] < y1) &
                        (coords[:, 2] >= x0) & (coords[:, 2] < x1)
                    )[0]
                    if len(patch_index) == 0:
                        continue
                    patch_coords = coords[patch_index].clone()
                    patch_coords[:, 1] -= y0
                    patch_coords[:, 2] -= x0
                    patch_latent = SparseTensor(
                        feats=feats[patch_index].contiguous(),
                        coords=patch_coords,
                    )
                    patch_gaussians = gs_decoder(patch_latent)
                    # Translate Gaussians to their world-space tile position
                    offset = torch.tensor([[i + 0.5, j + 0.5, 0.5]], device=self.device)
                    for g in patch_gaussians:
                        g._xyz = g._xyz + offset
                    if all_patch_lists is None:
                        all_patch_lists = [[g] for g in patch_gaussians]
                    else:
                        for k, g in enumerate(patch_gaussians):
                            all_patch_lists[k].append(g)
            # Concatenate all patches into a single Gaussian set per batch element
            # NOTE(review): if no tile contained any voxel, all_patch_lists is
            # still None here and this loop raises TypeError — confirm that
            # callers guarantee non-empty coords.
            merged_gaussians = []
            for gs_list in all_patch_lists:
                g0 = gs_list[0]
                if len(gs_list) > 1:
                    g0._features_dc = torch.cat([g._features_dc for g in gs_list], dim=0)
                    g0._opacity = torch.cat([g._opacity for g in gs_list], dim=0)
                    g0._rotation = torch.cat([g._rotation for g in gs_list], dim=0)
                    g0._scaling = torch.cat([g._scaling for g in gs_list], dim=0)
                    g0._xyz = torch.cat([g._xyz for g in gs_list], dim=0)
                merged_gaussians.append(g0)
            # Filter Gaussians with overly large kernels (outliers)
            for g in merged_gaussians:
                scale_norm = torch.sum(g.get_scaling ** 2, dim=1) ** 0.5
                keep = torch.where(scale_norm < 0.03)[0]
                g._features_dc = g._features_dc[keep]
                g._opacity = g._opacity[keep]
                g._rotation = g._rotation[keep]
                g._scaling = g._scaling[keep]
                g._xyz = g._xyz[keep]
            # Normalize to world-space coordinate convention
            eps = 1e-4
            center_offset = torch.tensor([[0.5, 0.5, 0.0]], device=self.device)
            for g in merged_gaussians:
                g.from_xyz(g.get_xyz / scale)
                g._xyz -= center_offset
                g.mininum_kernel_size /= scale
                # Clamp scaling from below so no kernel shrinks under the minimum.
                g.from_scaling(torch.max(
                    g.get_scaling / scale,
                    torch.tensor(g.mininum_kernel_size * (1 + eps), device=self.device),
                ))
            ret['gaussian'] = merged_gaussians
        return ret
# -----------------------------------------------------------------------
# Full pipeline
# -----------------------------------------------------------------------
    def run(
        self,
        image: Image.Image,
        width: int = 2,
        length: int = 2,
        div: int = 2,
        ss_optim: bool = True,
        ss_iterations: int = 3,
        ss_steps: int = 25,
        ss_rescale_t: float = 3.0,
        ss_t_noise: float = 0.6,
        ss_t_start: float = 0.8,
        ss_cfg_strength: float = 7.5,
        ss_alpha: float = 5.0,
        ss_batch_size: int = 1,
        slat_optim: bool = True,
        slat_steps: int = 25,
        slat_rescale_t: float = 3.0,
        slat_cfg_strength: float = 3.0,
        slat_batch_size: int = 1,
        # NOTE(review): mutable default — read-only downstream, so benign.
        formats: list = ['gaussian', 'mesh'],
        return_pointmap: bool = False,
        progress_callback=None,
    ) -> dict:
        """Run the full Extend3D pipeline: SS sampling → SLAT sampling → decode.

        ``ss_*`` parameters configure stage 1 (``sample_sparse_structure``),
        ``slat_*`` parameters configure stage 2 (``sample_slat``); see those
        methods for details.

        Returns:
            The decoded-formats dict from ``decode_slat``. When
            ``return_pointmap`` is True, a ``(decoded, pointmap_info)`` tuple
            is returned instead (the ``-> dict`` annotation does not cover
            that case).
        """
        # Stage 0: build the pointmap once and share it across both stages.
        pointmap_info = PointmapInfoMoGe(image, device=self.device)
        # Stage 1: occupied-voxel coordinates.
        coords = self.sample_sparse_structure(
            image, pointmap_info, ss_optim, width, length, div,
            iterations=ss_iterations,
            steps=ss_steps,
            rescale_t=ss_rescale_t,
            t_noise=ss_t_noise,
            t_start=ss_t_start,
            cfg_strength=ss_cfg_strength,
            alpha=ss_alpha,
            batch_size=ss_batch_size,
            progress_callback=progress_callback,
        ).detach()
        # Stage 2: per-voxel structured latents.
        slat = self.sample_slat(
            image, coords, pointmap_info, slat_optim,
            width, length, div,
            steps=slat_steps,
            rescale_t=slat_rescale_t,
            cfg_strength=slat_cfg_strength,
            batch_size=slat_batch_size,
            progress_callback=progress_callback,
        )
        # Stage 3: decode to Gaussians / mesh (no gradients needed here).
        with torch.no_grad():
            decoded = self.decode_slat(slat, width, length, formats=formats)
        if return_pointmap:
            return decoded, pointmap_info
        return decoded