File size: 5,029 Bytes
382733a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import os.path as osp
import numpy as np
import torch
import utils3d
import logging

import third_party.TRELLIS.trellis.modules.sparse as sp
from third_party.TRELLIS.trellis.pipelines import TrellisTextTo3DPipeline
from lib.util import generation, partfield

# Global logger
log = logging.getLogger(__name__)

def attn_cosine_sim(x, eps=1e-08):
    """Return the pairwise cosine similarity between the rows of each matrix.

    Only the first entry along the leading axis of ``x`` is used; what remains
    is treated as a 3-D batch whose last axis holds the feature vectors. For
    every batch item, each row is compared with every other row.

    Args:
        x: Tensor that is 3-D after its leading axis has been indexed away.
        eps: Lower bound applied to the product of the two vector norms, so
            that zero-length vectors do not cause a division by zero.

    Returns:
        Tensor with one square similarity matrix per batch item.
    """
    batch = x[0]  # TEMP: getting rid of redundant dimension, TBF
    lengths = batch.norm(dim=2, keepdim=True)
    # Outer product of the norms: entry (i, j) is |v_i| * |v_j|.
    length_products = torch.matmul(lengths, lengths.transpose(1, 2))
    dot_products = torch.matmul(batch, batch.transpose(1, 2))
    return dot_products / length_products.clamp(min=eps)

def _part_contrastive_loss(sim, labels):
    """Contrastive loss that pulls together features sharing a part label.

    For every row ``i`` that has at least one other element with the same
    label, the loss is ``-log(sum_pos exp(sim[i, j]) / sum_all exp(sim[i, k]))``
    where both sums exclude the diagonal; the result is the mean over those
    rows.

    Rows without a same-label partner are dropped *before* the logarithm is
    taken. Filtering afterwards (as a mask on the per-row loss) leaves a
    ``log(0)`` node in the graph whose backward pass evaluates ``0 / 0`` and
    poisons every gradient with NaN.

    Args:
        sim: Square ``(N, N)`` similarity matrix.
        labels: Tensor with ``N`` integer labels (any shape that flattens to N).

    Returns:
        Scalar loss tensor, or ``None`` when no row has a same-label partner.
    """
    labels = labels.view(-1, 1)
    same_part = (labels == labels.T).float()

    # Exclude self-pairs from both the positives and the normaliser.
    off_diag = torch.ones_like(same_part) - torch.eye(same_part.size(0), device=same_part.device)
    positives = same_part * off_diag
    exp_sim = torch.exp(sim) * off_diag

    valid = positives.sum(dim=1) > 0
    if not bool(valid.any()):
        return None

    numerator = (exp_sim * positives).sum(dim=1)[valid]
    denominator = exp_sim.sum(dim=1)[valid]
    per_row = -torch.log(numerator / (denominator + 1e-8))
    return per_row.mean()


def optimize_self_similarity(cfg, app_text, output_dir):
    """Sample a structured latent with part-aware self-similarity guidance.

    Loads the structure voxels and part-field planes from ``output_dir``,
    clusters the voxels into parts, then alternates one flow-sampler step with
    one AdamW step on a contrastive loss that makes features of voxels in the
    same part cluster more similar. The latent with the lowest guidance loss
    is de-normalised and handed to ``generation.decode_slat``.

    Args:
        cfg: Configuration object. Reads ``trellis_text_model_name``,
            ``flow_model_in_channels``, ``log_every`` and, under
            ``sim_guidance``: ``num_part_clusters``, ``learning_rate``,
            ``cfg_strength``, ``cfg_interval``, ``steps``, ``rescale_t`` and
            ``loss_weight``.
        app_text: Text prompt passed to the pipeline's ``get_cond``.
        output_dir: Directory containing ``voxels/struct_voxels.ply`` and
            ``partfield/part_feat_struct_mesh_zup_batch_part_plane.npy``; the
            output paths ``out_sim.glb`` and ``out_gaussian_sim.mp4`` inside it
            are passed to ``generation.decode_slat``.

    Returns:
        None.
    """
    log.info("Starting self-similarity optimization...")

    generation_pipeline = TrellisTextTo3DPipeline.from_pretrained(cfg.trellis_text_model_name)
    generation_pipeline.cuda()

    # Load Structure Data
    struct_coords = utils3d.io.read_ply(osp.join(output_dir, 'voxels', 'struct_voxels.ply'))[0]
    struct_coords = torch.from_numpy(struct_coords).float().cuda()
    # NOTE(review): presumably maps coordinates in [-0.5, 0.5) onto a 64^3 grid — confirm.
    struct_coords = ((struct_coords + 0.5) * 64).long()

    # Prepend a zero column (presumably the batch index expected by SparseTensor — confirm).
    zeros = torch.zeros((struct_coords.size(0), 1), dtype=struct_coords.dtype, device=struct_coords.device)
    struct_coords = torch.cat([zeros, struct_coords], dim=1)

    # Load partfield planes
    # NOTE(review): allow_pickle=True executes arbitrary pickled code on load;
    # only safe while this file is produced by our own pipeline.
    path = osp.join(output_dir, "partfield", "part_feat_struct_mesh_zup_batch_part_plane.npy")
    struct_part_planes = torch.from_numpy(np.load(path, allow_pickle=True)).cuda()

    struct_labels = partfield.cluster_geoms(struct_coords, struct_part_planes, num_clusters=cfg.sim_guidance.num_part_clusters)

    # Optimization Starts...
    struct_labels = torch.from_numpy(struct_labels.flatten()).cuda()
    # Kept on the CPU generator so seeded runs reproduce the same initial noise.
    struct_feats_params = torch.nn.Parameter(torch.randn((struct_coords.shape[0], cfg.flow_model_in_channels)), requires_grad=True)

    param_list = [struct_feats_params]
    optimizer = torch.optim.AdamW(param_list, lr=cfg.sim_guidance.learning_rate)
    # Constant multiplier: the learning rate stays at its initial value.
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda x: 1)

    best_loss = float('inf')
    feats = None

    cond = generation_pipeline.get_cond([app_text])

    flow_model = generation_pipeline.models['slat_flow_model']
    sampler_params = {
        "cfg_strength": cfg.sim_guidance.cfg_strength,
        "cfg_interval": cfg.sim_guidance.cfg_interval,
    }

    # Time grid from 1 to 0, warped by rescale_t.
    t_seq = np.linspace(1, 0, cfg.sim_guidance.steps + 1)
    t_seq = cfg.sim_guidance.rescale_t * t_seq / (1 + (cfg.sim_guidance.rescale_t - 1) * t_seq)
    t_pairs = list(zip(t_seq[:-1], t_seq[1:]))

    std = torch.tensor(generation_pipeline.slat_normalization['std'])[None].cuda()
    mean = torch.tensor(generation_pipeline.slat_normalization['mean'])[None].cuda()

    # Loop-invariant: the sparse coordinates never change between steps.
    sparse_coords = struct_coords.int()
    last_iteration = len(t_pairs) - 1

    log.info(f"Beginning self-similarity guidance + flow sampling loop for {len(t_pairs)} steps...")
    for iteration, (t, t_prev) in enumerate(t_pairs):
        optimizer.zero_grad()

        # Diffusion
        noise = sp.SparseTensor(
            feats=struct_feats_params.clone().cuda(),
            coords=sparse_coords,
        ).cuda()

        with torch.no_grad():
            out = generation_pipeline.slat_sampler.sample_once(flow_model, noise, t, t_prev, **cond, **sampler_params)

        # Overwrite the parameter storage with the sampler output so the
        # guidance step below refines the freshly sampled latent.
        struct_feats_params.data = out.pred_x_prev.feats

        # Optimization - Structure Loss
        # NOTE(review): guidance and best-loss bookkeeping are skipped on the
        # final step, so the last sample is never eligible for `feats` unless
        # the fallback after the loop triggers — confirm this is intended.
        if iteration >= last_iteration:
            continue

        sim = attn_cosine_sim(struct_feats_params[None, None, ...])[0]
        struct_loss = _part_contrastive_loss(sim, struct_labels)
        if struct_loss is None:
            # No voxel shares a label with another voxel: nothing to optimise,
            # and back-propagating would only inject NaNs.
            log.warning("Step: %d, no voxel has a same-part partner; skipping guidance update.", iteration)
            continue

        total_loss = cfg.sim_guidance.loss_weight * struct_loss
        total_loss.backward()
        optimizer.step()
        scheduler.step()

        loss_value = total_loss.item()
        if (iteration == 0) or (iteration + 1) % cfg.log_every == 0:
            message = f"Step: {iteration}, Structure Loss: {struct_loss.item():.4f}, Total Loss: {loss_value:.4f}"
            log.info(message)

        if loss_value < best_loss:
            best_loss = loss_value
            # De-normalise into the space the decoder expects.
            feats = struct_feats_params.detach() * std + mean

    if feats is None:
        # No guidance step recorded a finite loss (e.g. steps <= 1); decode the
        # final sampled latent instead of passing None to the decoder.
        log.warning("No guidance step produced a usable loss; decoding the final sampled latent.")
        feats = struct_feats_params.detach().to(std.device) * std + mean

    # Decode SLAT
    log.info("Decoding output SLAT...")
    out_meshpath = osp.join(output_dir,  'out_sim.glb')
    out_gspath = osp.join(output_dir,  'out_gaussian_sim.mp4')
    generation.decode_slat(generation_pipeline, feats, struct_coords, out_meshpath, out_gspath)