import os.path as osp
import numpy as np
import torch
import torch.nn.functional as F
import utils3d
from PIL import Image
import logging
import third_party.TRELLIS.trellis.modules.sparse as sp
from third_party.TRELLIS.trellis.pipelines import TrellisImageTo3DPipeline
from lib.util import partfield, generation

# Module-level logger
log = logging.getLogger(__name__)


def _match_part_targets(app_feats, app_labels, struct_labels, app_point_feats, struct_point_feats):
    """Build per-part appearance targets for the structure voxels.

    For every label in ``app_labels`` that also occurs in ``struct_labels``,
    each structure voxel of that part is matched to the appearance voxel of the
    same part with the smallest distance ``(1 - dot) / 2`` between their
    PartField point features. The matched voxel's appearance latent becomes
    that structure voxel's target.

    Labels and point features do not change during optimisation, so this is
    computed once rather than on every sampler step.

    Args:
        app_feats: appearance latent features, one row per appearance voxel.
        app_labels: 1-D part label per appearance voxel.
        struct_labels: 1-D part label per structure voxel.
        app_point_feats: PartField feature row per appearance voxel.
        struct_point_feats: PartField feature row per structure voxel.

    Returns:
        A list of ``(struct_mask, target_feats)`` pairs, one per shared label.
        The list is empty when no label is shared.
    """
    part_targets = []
    for label in torch.unique(app_labels):
        app_mask = app_labels == label
        struct_mask = struct_labels == label
        if not app_mask.any() or not struct_mask.any():
            continue
        # NOTE(review): treated as a cosine similarity, which presumes
        # partfield.cosegment_part returns unit-normalised features -- confirm.
        cos_sim = torch.matmul(struct_point_feats[struct_mask], app_point_feats[app_mask].T)
        cos_dist = (1 - cos_sim) / 2.
        nearest = torch.argmin(cos_dist, dim=1)
        part_targets.append((struct_mask, app_feats[app_mask][nearest]))
    return part_targets


def _appearance_loss(struct_feats, part_targets):
    """Return the mean over parts of the MSE between structure features and their targets.

    ``part_targets`` must be non-empty (see ``_match_part_targets``).
    """
    per_part = [F.mse_loss(struct_feats[mask], target) for mask, target in part_targets]
    return torch.stack(per_part).mean()


def optimize_appearance(cfg, output_dir):
    """Transfer appearance onto the structure voxels via guided TRELLIS SLAT sampling.

    Starting from random latent features on the structure voxels, every
    flow-sampler step (``slat_sampler.sample_once``) is followed by one AdamW
    step. Except after the final sampler step, that AdamW step pulls each
    structure voxel's features toward the appearance latent of its nearest
    same-part appearance voxel. The best-loss iterate (or the final sample if
    no guided step recorded a finite loss) is un-normalised and decoded.

    Args:
        cfg: configuration object. Reads ``trellis_img_model_name``,
            ``latent_name``, ``flow_model_in_channels``, ``log_every`` and
            ``app_guidance.{num_part_clusters, learning_rate, cfg_strength,
            cfg_interval, steps, rescale_t, loss_weight}``.
        output_dir: directory holding ``latents/<latent_name>/appearance.npz``,
            ``voxels/struct_voxels.ply``, ``app_image.png`` and the
            ``partfield/*_part_plane.npy`` files. ``out_app.glb`` and
            ``out_gaussian_app.mp4`` paths inside it are passed to
            ``generation.decode_slat``.
    """
    log.info("Starting appearance optimization...")
    generation_pipeline = TrellisImageTo3DPipeline.from_pretrained(cfg.trellis_img_model_name)
    generation_pipeline.cuda()

    # load appearance and structure data
    path = osp.join(output_dir, 'latents', cfg.latent_name, "appearance.npz")
    with np.load(path) as data:  # closes the underlying .npz archive
        app_feats = torch.from_numpy(data['feats']).cuda()
        app_coords = torch.from_numpy(data['coords']).cuda()

    struct_coords = utils3d.io.read_ply(osp.join(output_dir, 'voxels', 'struct_voxels.ply'))[0]
    struct_coords = torch.from_numpy(struct_coords).float().cuda()
    # Map vertex positions to integer indices on a 64^3 grid
    # (presumably positions lie in [-0.5, 0.5]; confirm against the voxel writer).
    struct_coords = ((struct_coords + 0.5) * 64).long()
    with Image.open(osp.join(output_dir, 'app_image.png')) as img:
        app_image = img.convert('RGB')
    # Prepend a zero column (presumably the batch index expected by sp.SparseTensor).
    zeros = torch.zeros((struct_coords.size(0), 1), dtype=struct_coords.dtype, device=struct_coords.device)
    struct_coords = torch.cat([zeros, struct_coords], dim=1)

    # Load partfield planes
    # NOTE: allow_pickle=True can execute pickled code; these files are assumed
    # to be produced by this pipeline, never supplied by untrusted users.
    path = osp.join(output_dir, 'partfield', 'part_feat_struct_mesh_zup_batch_part_plane.npy')
    struct_part_planes = torch.from_numpy(np.load(path, allow_pickle=True)).cuda()
    path = osp.join(output_dir, 'partfield', 'part_feat_app_mesh_zup_batch_part_plane.npy')
    app_part_planes = torch.from_numpy(np.load(path, allow_pickle=True)).cuda()

    app_labels, struct_labels, point_feat1, point_feat2 = partfield.cosegment_part(
        app_coords, app_part_planes, struct_coords, struct_part_planes,
        cfg.app_guidance.num_part_clusters)

    app_labels = torch.from_numpy(app_labels.flatten()).cuda()
    struct_labels = torch.from_numpy(struct_labels.flatten()).cuda()
    point_feat1 = torch.from_numpy(point_feat1).cuda()  # indexed by appearance voxels
    point_feat2 = torch.from_numpy(point_feat2).cuda()  # indexed by structure voxels

    # Nearest-neighbour targets depend only on fixed labels/features: compute once.
    part_targets = _match_part_targets(app_feats, app_labels, struct_labels, point_feat1, point_feat2)
    if not part_targets:
        # Without this guard the loss would be 0/0 = NaN and AdamW would poison the features.
        log.warning("No part label is shared between appearance and structure; "
                    "appearance guidance is disabled.")

    # Optimization Starts
    # Noise is drawn on the CPU and then moved, so seeded runs get the same initial features.
    struct_feats_params = torch.nn.Parameter(
        torch.randn((struct_coords.shape[0], cfg.flow_model_in_channels)).cuda(), requires_grad=True)
    optimizer = torch.optim.AdamW([struct_feats_params], lr=cfg.app_guidance.learning_rate)
    # Constant learning rate (multiplier 1 at every step).
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda x: 1)
    best_loss = float('inf')
    feats = None

    image = generation_pipeline.preprocess_image(app_image)
    cond = generation_pipeline.get_cond([image])
    flow_model = generation_pipeline.models['slat_flow_model']
    sampler_params = {
        "cfg_strength": cfg.app_guidance.cfg_strength,
        "cfg_interval": cfg.app_guidance.cfg_interval,
    }

    # Time grid from 1 to 0, warped by rescale_t (endpoints 0 and 1 are preserved).
    t_seq = np.linspace(1, 0, cfg.app_guidance.steps + 1)
    t_seq = cfg.app_guidance.rescale_t * t_seq / (1 + (cfg.app_guidance.rescale_t - 1) * t_seq)
    t_pairs = list((t_seq[i], t_seq[i + 1]) for i in range(cfg.app_guidance.steps))

    std = torch.tensor(generation_pipeline.slat_normalization['std'])[None].cuda()
    mean = torch.tensor(generation_pipeline.slat_normalization['mean'])[None].cuda()

    num_steps = len(t_pairs)
    log.info(f"Beginning guidance + flow sampling loop for {num_steps} steps...")
    for iteration, (t, t_prev) in enumerate(t_pairs):
        optimizer.zero_grad()

        # Diffusion: advance the current features from t to t_prev with the flow model.
        noise = sp.SparseTensor(
            feats=struct_feats_params.detach().clone(),
            coords=struct_coords.int(),
        ).cuda()
        with torch.no_grad():
            out = generation_pipeline.slat_sampler.sample_once(flow_model, noise, t, t_prev, **cond, **sampler_params)
        # NOTE(review): assumes pred_x_prev.feats keeps the row order of struct_coords,
        # since the part masks index rows by that order -- confirm for sp.SparseTensor.
        struct_feats_params.data = out.pred_x_prev.feats

        # Optimization: the final step only denoises; guidance needs shared parts.
        if iteration == num_steps - 1 or not part_targets:
            continue

        app_loss = cfg.app_guidance.loss_weight * _appearance_loss(struct_feats_params, part_targets)
        total_loss = app_loss
        total_loss.backward()
        optimizer.step()
        scheduler.step()

        loss_value = total_loss.item()
        if (iteration == 0) or (iteration + 1) % cfg.log_every == 0:
            message = f"Step: {iteration}, Appearance Loss: {app_loss.item():.4f}, Total Loss: {loss_value:.4f}"
            log.info(message)

        # NaN losses never compare smaller, so they are never selected.
        if loss_value < best_loss:
            best_loss = loss_value
            feats = struct_feats_params.detach() * std + mean

    if feats is None:
        # No guided step produced a finite loss (single step, no shared parts, or NaN):
        # decode the final denoised sample instead of handing None to the decoder.
        log.warning("No guided step recorded a finite loss; decoding the final sample instead.")
        feats = struct_feats_params.detach() * std + mean

    # Decode SLAT
    # NOTE(review): ``feats`` is normally the best-loss guided iterate, captured
    # before the final sampler step; confirm this is preferred over the final sample.
    log.info("Decoding output SLAT...")
    out_meshpath = osp.join(output_dir, 'out_app.glb')
    out_gspath = osp.join(output_dir, 'out_gaussian_app.mp4')
    generation.decode_slat(generation_pipeline, feats, struct_coords, out_meshpath, out_gspath)