Spaces:
Running
Running
File size: 5,732 Bytes
382733a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
import os.path as osp
import numpy as np
import torch
import torch.nn.functional as F
import utils3d
from PIL import Image
import logging
import third_party.TRELLIS.trellis.modules.sparse as sp
from third_party.TRELLIS.trellis.pipelines import TrellisImageTo3DPipeline
from lib.util import partfield, generation
# Module-level logger, named after this module, used by every function below.
log: logging.Logger = logging.getLogger(__name__)
def _build_part_matches(app_labels, struct_labels, point_feat1, point_feat2, app_feats):
    """Pair every structure voxel with the appearance feature it should match, per part.

    For each label that occurs in ``app_labels``, every structure voxel with the
    same label is matched to the appearance point of that label whose PartField
    feature has the smallest ``(1 - dot) / 2`` distance, i.e. the largest dot
    product. The dot product is treated as cosine similarity, which assumes the
    point features are L2-normalised (NOTE(review): confirm in
    ``partfield.cosegment_part``).

    Args:
        app_labels: 1-D part labels of the appearance points.
        struct_labels: 1-D part labels of the structure voxels.
        point_feat1: PartField features of the appearance points, one row per label.
        point_feat2: PartField features of the structure voxels, one row per label.
        app_feats: appearance latent features, one row per appearance point.

    Returns:
        A list of ``(struct_mask, matched)`` pairs, one per label present in both
        shapes. ``struct_mask`` selects structure voxels and ``matched`` holds,
        row for row, the appearance features those voxels are pulled towards.
        The list is empty when no label is shared.
    """
    matches = []
    for label in torch.unique(app_labels):
        app_mask = app_labels == label
        struct_mask = struct_labels == label
        # A label missing from either shape has nothing to match against.
        if not app_mask.any() or not struct_mask.any():
            continue
        cos_sim = point_feat2[struct_mask] @ point_feat1[app_mask].T
        cos_dist = (1 - cos_sim) / 2.
        nearest = torch.argmin(cos_dist, dim=1)
        matches.append((struct_mask, app_feats[app_mask][nearest]))
    return matches


def _appearance_loss(struct_feats, matches):
    """Return the mean, over matched parts, of the MSE between voxels and their targets.

    ``matches`` must be non-empty (see ``_build_part_matches``).
    """
    part_losses = [F.mse_loss(struct_feats[struct_mask], matched) for struct_mask, matched in matches]
    return sum(part_losses) / len(part_losses)


def optimize_appearance(cfg, output_dir):
    """Transfer appearance onto the structure voxels via guided SLAT flow sampling.

    Loads the appearance latents (``latents/<latent_name>/appearance.npz``), the
    structure voxels (``voxels/struct_voxels.ply``), the PartField planes and
    ``app_image.png`` from ``output_dir``. It co-segments both shapes into parts,
    then runs the TRELLIS SLAT flow sampler conditioned on the image. After each
    flow step except the last, the current latent is updated by AdamW towards
    the appearance features of its part-wise nearest neighbours. The
    de-normalised latent with the lowest guidance loss is decoded to
    ``out_app.glb`` and ``out_gaussian_app.mp4``. If no guided step recorded a
    loss, the final flow sample is decoded instead.

    Args:
        cfg: config providing ``trellis_img_model_name``, ``latent_name``,
            ``flow_model_in_channels``, ``log_every`` and an ``app_guidance``
            group (``steps``, ``rescale_t``, ``cfg_strength``, ``cfg_interval``,
            ``learning_rate``, ``loss_weight``, ``num_part_clusters``).
        output_dir: directory holding the inputs and receiving the outputs.

    Raises:
        ValueError: if ``cfg.app_guidance.steps`` is smaller than 1.
    """
    steps = cfg.app_guidance.steps
    if steps < 1:
        raise ValueError(f"cfg.app_guidance.steps must be >= 1, got {steps}")

    log.info("Starting appearance optimization...")
    generation_pipeline = TrellisImageTo3DPipeline.from_pretrained(cfg.trellis_img_model_name)
    generation_pipeline.cuda()

    # Appearance latents; the context manager closes the .npz archive.
    path = osp.join(output_dir, 'latents', cfg.latent_name, "appearance.npz")
    with np.load(path) as data:
        app_feats = torch.from_numpy(data['feats']).cuda()
        app_coords = torch.from_numpy(data['coords']).cuda()

    # Structure voxels: PLY vertices mapped to integer grid indices.
    # Assumes vertices lie in [-0.5, 0.5) on a 64^3 grid — TODO confirm.
    struct_coords = utils3d.io.read_ply(osp.join(output_dir, 'voxels', 'struct_voxels.ply'))[0]
    struct_coords = torch.from_numpy(struct_coords).float().cuda()
    struct_coords = ((struct_coords + 0.5) * 64).long()
    # Prepend a zero column (presumably the batch index expected by SparseTensor coords).
    zeros = torch.zeros((struct_coords.size(0), 1), dtype=struct_coords.dtype, device=struct_coords.device)
    struct_coords = torch.cat([zeros, struct_coords], dim=1)

    with Image.open(osp.join(output_dir, 'app_image.png')) as img:
        app_image = img.convert('RGB')

    # PartField planes. NOTE(review): allow_pickle=True lets np.load unpickle
    # arbitrary objects; only load files produced by this pipeline.
    path = osp.join(output_dir, 'partfield', 'part_feat_struct_mesh_zup_batch_part_plane.npy')
    struct_part_planes = torch.from_numpy(np.load(path, allow_pickle=True)).cuda()
    path = osp.join(output_dir, 'partfield', 'part_feat_app_mesh_zup_batch_part_plane.npy')
    app_part_planes = torch.from_numpy(np.load(path, allow_pickle=True)).cuda()
    app_labels, struct_labels, point_feat1, point_feat2 = partfield.cosegment_part(
        app_coords, app_part_planes, struct_coords, struct_part_planes,
        cfg.app_guidance.num_part_clusters,
    )
    app_labels = torch.from_numpy(app_labels.flatten()).cuda()
    struct_labels = torch.from_numpy(struct_labels.flatten()).cuda()
    point_feat1 = torch.from_numpy(point_feat1).cuda()
    point_feat2 = torch.from_numpy(point_feat2).cuda()

    # Matching targets depend only on fixed features, so build them once, not per step.
    matches = _build_part_matches(app_labels, struct_labels, point_feat1, point_feat2, app_feats)
    if not matches:
        log.warning("No part label is shared by appearance and structure; appearance guidance is disabled.")

    # The latent being sampled is itself the optimised parameter. randn is drawn on
    # the CPU (same RNG stream as before) and moved to the GPU once, so the
    # parameter keeps its device when its data is overwritten in the loop.
    struct_feats_params = torch.nn.Parameter(
        torch.randn((struct_coords.shape[0], cfg.flow_model_in_channels)).cuda(),
        requires_grad=True,
    )
    optimizer = torch.optim.AdamW([struct_feats_params], lr=cfg.app_guidance.learning_rate)
    # Constant multiplier of 1: the learning rate stays fixed.
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda _: 1)

    image = generation_pipeline.preprocess_image(app_image)
    cond = generation_pipeline.get_cond([image])
    flow_model = generation_pipeline.models['slat_flow_model']
    sampler_params = {
        "cfg_strength": cfg.app_guidance.cfg_strength,
        "cfg_interval": cfg.app_guidance.cfg_interval,
    }

    # Timesteps from 1 down to 0, warped by rescale_t, consumed as (t, t_prev) pairs.
    rescale_t = cfg.app_guidance.rescale_t
    t_seq = np.linspace(1, 0, steps + 1)
    t_seq = rescale_t * t_seq / (1 + (rescale_t - 1) * t_seq)
    t_pairs = list(zip(t_seq[:-1], t_seq[1:]))

    std = torch.tensor(generation_pipeline.slat_normalization['std'])[None].cuda()
    mean = torch.tensor(generation_pipeline.slat_normalization['mean'])[None].cuda()

    best_loss = float('inf')
    feats = None
    last_step = len(t_pairs) - 1
    log.info(f"Beginning guidance + flow sampling loop for {len(t_pairs)} steps...")
    for iteration, (t, t_prev) in enumerate(t_pairs):
        optimizer.zero_grad()

        # Flow step from t to t_prev on a detached copy of the current latent.
        noise = sp.SparseTensor(
            feats=struct_feats_params.detach().clone(),
            coords=struct_coords.int(),
        ).cuda()
        with torch.no_grad():
            out = generation_pipeline.slat_sampler.sample_once(
                flow_model, noise, t, t_prev, **cond, **sampler_params)
        # Overwrite the parameter's data so the optimizer keeps tracking the same tensor.
        struct_feats_params.data = out.pred_x_prev.feats

        # The final step is left unguided; without shared parts there is nothing to guide.
        if iteration == last_step or not matches:
            continue

        app_loss = cfg.app_guidance.loss_weight * _appearance_loss(struct_feats_params, matches)
        total_loss = app_loss
        total_loss.backward()
        optimizer.step()
        scheduler.step()

        if (iteration == 0) or (iteration + 1) % cfg.log_every == 0:
            message = f"Step: {iteration}, Appearance Loss: {app_loss.item():.4f}, Total Loss: {total_loss.item():.4f}"
            log.info(message)

        # Keep the de-normalised latent (after this step's update) with the lowest loss.
        # NOTE(review): the output of the final, unguided flow step is never scored,
        # so it is only used as the fallback below — confirm this is intended.
        loss_value = total_loss.item()
        if loss_value < best_loss:
            best_loss = loss_value
            feats = struct_feats_params.detach() * std + mean

    if feats is None:
        # No guided step recorded a finite loss (single step, no shared parts, or NaN losses).
        log.warning("No guided latent was recorded; decoding the final flow sample instead.")
        feats = struct_feats_params.detach() * std + mean

    log.info("Decoding output SLAT...")
    out_meshpath = osp.join(output_dir, 'out_app.glb')
    out_gspath = osp.join(output_dir, 'out_gaussian_app.mp4')
    generation.decode_slat(generation_pipeline, feats, struct_coords, out_meshpath, out_gspath)