Spaces:
Sleeping
Sleeping
| import os.path as osp | |
| import numpy as np | |
| import torch | |
| import utils3d | |
| import logging | |
| import third_party.TRELLIS.trellis.modules.sparse as sp | |
| from third_party.TRELLIS.trellis.pipelines import TrellisTextTo3DPipeline | |
| from lib.util import generation, partfield | |
| # Global logger | |
| log = logging.getLogger(__name__) | |
| def attn_cosine_sim(x, eps=1e-08): | |
| x = x[0] # TEMP: getting rid of redundant dimension, TBF | |
| norm1 = x.norm(dim=2, keepdim=True) | |
| factor = torch.clamp(norm1 @ norm1.permute(0, 2, 1), min=eps) | |
| sim_matrix = (x @ x.permute(0, 2, 1)) / factor | |
| return sim_matrix | |
| def optimize_self_similarity(cfg, app_text, output_dir): | |
| log.info("Starting self-similarity optimization...") | |
| generation_pipeline = TrellisTextTo3DPipeline.from_pretrained(cfg.trellis_text_model_name) | |
| generation_pipeline.cuda() | |
| # Load Structure Data | |
| struct_coords = utils3d.io.read_ply(osp.join(output_dir, 'voxels', 'struct_voxels.ply'))[0] | |
| struct_coords = torch.from_numpy(struct_coords).float().cuda() | |
| struct_coords = ((struct_coords + 0.5) * 64).long() | |
| zeros = torch.zeros((struct_coords.size(0), 1), dtype=struct_coords.dtype, device=struct_coords.device) | |
| struct_coords = torch.cat([zeros, struct_coords], dim=1) | |
| # Load partfield planes | |
| path = osp.join(output_dir, "partfield", "part_feat_struct_mesh_zup_batch_part_plane.npy") | |
| struct_part_planes = torch.from_numpy(np.load(path, allow_pickle=True)).cuda() | |
| struct_labels = partfield.cluster_geoms(struct_coords, struct_part_planes, num_clusters=cfg.sim_guidance.num_part_clusters) | |
| # Optimization Starts... | |
| struct_labels = torch.from_numpy(struct_labels.flatten()).cuda() | |
| struct_feats_params = torch.nn.Parameter(torch.randn((struct_coords.shape[0], cfg.flow_model_in_channels)), requires_grad=True) | |
| param_list = [struct_feats_params] | |
| optimizer = torch.optim.AdamW(param_list, lr=cfg.sim_guidance.learning_rate) | |
| scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda x: 1) | |
| best_loss = float('inf') | |
| feats = None | |
| cond = generation_pipeline.get_cond([app_text]) | |
| flow_model = generation_pipeline.models['slat_flow_model'] | |
| sampler_params={ | |
| "cfg_strength": cfg.sim_guidance.cfg_strength, | |
| "cfg_interval": cfg.sim_guidance.cfg_interval, | |
| } | |
| t_seq = np.linspace(1, 0, cfg.sim_guidance.steps + 1) | |
| t_seq = cfg.sim_guidance.rescale_t * t_seq / (1 + (cfg.sim_guidance.rescale_t - 1) * t_seq) | |
| t_pairs = list((t_seq[i], t_seq[i + 1]) for i in range(cfg.sim_guidance.steps)) | |
| std = torch.tensor(generation_pipeline.slat_normalization['std'])[None].cuda() | |
| mean = torch.tensor(generation_pipeline.slat_normalization['mean'])[None].cuda() | |
| log.info(f"Beginning self-similarity guidance + flow sampling loop for {len(t_pairs)} steps...") | |
| for iteration, (t, t_prev) in enumerate(t_pairs): | |
| optimizer.zero_grad() | |
| # Diffusion | |
| struct_feats_params_clone = struct_feats_params.clone().cuda() | |
| noise = sp.SparseTensor( | |
| feats = struct_feats_params_clone, | |
| coords = struct_coords.int(), | |
| ).cuda() | |
| with torch.no_grad(): | |
| out = generation_pipeline.slat_sampler.sample_once(flow_model, noise, t, t_prev, **cond, **sampler_params) | |
| sample = out.pred_x_prev | |
| struct_feats_params.data = sample.feats | |
| # Optimization - Structure Loss | |
| if iteration < len(t_pairs) - 1: | |
| labels = struct_labels.view(-1,1) | |
| sim = attn_cosine_sim(struct_feats_params[None, None, ...])[0] | |
| mask = (labels == labels.T).float() | |
| logits_mask = torch.ones_like(mask) - torch.eye(mask.size(0), device=struct_feats_params.device) | |
| mask = mask * logits_mask | |
| exp_sim = torch.exp(sim) * logits_mask | |
| numerator = (exp_sim * mask).sum(dim=1) | |
| denominator = exp_sim.sum(dim=1) | |
| struct_loss = -torch.log(numerator / (denominator + 1e-8)) | |
| struct_loss = struct_loss[mask.sum(dim=1) > 0].mean() | |
| total_loss = cfg.sim_guidance.loss_weight * struct_loss | |
| total_loss.backward() | |
| optimizer.step() | |
| scheduler.step() | |
| if (iteration == 0) or (iteration + 1) % cfg.log_every == 0: | |
| message = f"Step: {iteration}, Structure Loss: {struct_loss.item():.4f}, Total Loss: {total_loss.item():.4f}" | |
| log.info(message) | |
| if total_loss < best_loss: | |
| best_loss = total_loss.item() | |
| feats = struct_feats_params.detach() * std + mean | |
| # Decode SLAT | |
| log.info("Decoding output SLAT...") | |
| out_meshpath = osp.join(output_dir, 'out_sim.glb') | |
| out_gspath = osp.join(output_dir, 'out_gaussian_sim.mp4') | |
| generation.decode_slat(generation_pipeline, feats, struct_coords, out_meshpath, out_gspath) |