File size: 5,732 Bytes
382733a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import os.path as osp
import numpy as np
import torch
import torch.nn.functional as F
import utils3d
from PIL import Image
import logging

import third_party.TRELLIS.trellis.modules.sparse as sp
from third_party.TRELLIS.trellis.pipelines import TrellisImageTo3DPipeline
from lib.util import partfield, generation

# Module-level logger, named after this module via ``__name__``. The code
# shown here attaches no handlers and sets no level of its own.
log = logging.getLogger(__name__)

def _build_time_pairs(steps, rescale_t):
    """Return the ``(t, t_prev)`` pairs of the rescaled sampling schedule.

    ``steps + 1`` timesteps are spaced linearly from 1 down to 0, warped with
    ``rescale_t * t / (1 + (rescale_t - 1) * t)`` and paired consecutively,
    which yields ``steps`` pairs.
    """
    t_seq = np.linspace(1, 0, steps + 1)
    t_seq = rescale_t * t_seq / (1 + (rescale_t - 1) * t_seq)
    return list(zip(t_seq[:-1], t_seq[1:]))


def _match_part_targets(app_feats, app_labels, app_point_feats, struct_labels, struct_point_feats):
    """Pair every structure voxel with the appearance latent of its nearest part-mate.

    For each label present in ``app_labels`` that also occurs in
    ``struct_labels``, every structure point is matched to the appearance
    point of the same label with the largest feature dot product.

    NOTE(review): the dot product is treated as a cosine similarity, which
    presumes the point features are unit-normalised — confirm against
    ``partfield.cosegment_part``.

    Returns:
        A list of ``(struct_mask, matched_app_feats)`` tuples, one per shared
        label. Nothing in it depends on the latents being optimised, so it
        only has to be computed once.
    """
    targets = []
    for label in torch.unique(app_labels):
        app_mask = app_labels == label
        struct_mask = struct_labels == label

        if not app_mask.any() or not struct_mask.any():
            continue

        cos_sim = torch.matmul(struct_point_feats[struct_mask], app_point_feats[app_mask].T)
        cos_dist = (1 - cos_sim) / 2.
        nearest = torch.argmin(cos_dist, dim=1)

        targets.append((struct_mask, app_feats[app_mask][nearest]))
    return targets


def optimize_appearance(cfg, output_dir):
    """Sample structured latents for the structure voxels under appearance guidance.

    At every step of the flow sampler the current latents are advanced with
    ``slat_sampler.sample_once``. On all steps except the last they are then
    nudged by one AdamW step towards the appearance latents of the
    best-matching points within the same co-segmented part (MSE loss, averaged
    over parts and scaled by ``cfg.app_guidance.loss_weight``). The latents
    with the lowest guidance loss are de-normalised and handed to
    ``generation.decode_slat``.

    Args:
        cfg: Configuration object. Fields read: ``trellis_img_model_name``,
            ``latent_name``, ``flow_model_in_channels``, ``log_every`` and,
            under ``app_guidance``: ``num_part_clusters``, ``learning_rate``,
            ``cfg_strength``, ``cfg_interval``, ``steps``, ``rescale_t``,
            ``loss_weight``.
        output_dir: Working directory. Files read from it:
            ``latents/<latent_name>/appearance.npz`` (keys ``feats`` and
            ``coords``), ``voxels/struct_voxels.ply``, ``app_image.png`` and
            the two ``partfield/part_feat_*_part_plane.npy`` files. The paths
            ``out_app.glb`` and ``out_gaussian_app.mp4`` inside it are passed
            to ``generation.decode_slat``.

    Returns:
        None.
    """
    log.info("Starting appearance optimization...")

    generation_pipeline = TrellisImageTo3DPipeline.from_pretrained(cfg.trellis_img_model_name)
    generation_pipeline.cuda()

    # load appearance and structure data
    path = osp.join(output_dir, 'latents', cfg.latent_name, "appearance.npz")
    # NpzFile keeps the archive open until closed; the arrays are copied to
    # the GPU inside the block so the handle can be released right away.
    with np.load(path) as data:
        app_feats = torch.from_numpy(data['feats']).cuda()
        app_coords = torch.from_numpy(data['coords']).cuda()

    struct_coords = utils3d.io.read_ply(osp.join(output_dir, 'voxels', 'struct_voxels.ply'))[0]
    struct_coords = torch.from_numpy(struct_coords).float().cuda()
    # presumably maps coordinates from [-0.5, 0.5) onto a 64^3 grid — TODO confirm
    struct_coords = ((struct_coords + 0.5) * 64).long()

    # convert() returns an independent image, so the file can be closed here.
    with Image.open(osp.join(output_dir, 'app_image.png')) as raw_image:
        app_image = raw_image.convert('RGB')

    # Prepend a column of zeros (presumably the batch index expected by the
    # sparse tensor — TODO confirm).
    zeros = torch.zeros((struct_coords.size(0), 1), dtype=struct_coords.dtype, device=struct_coords.device)
    struct_coords = torch.cat([zeros, struct_coords], dim=1)

    # Load partfield planes
    # SECURITY NOTE: allow_pickle=True executes arbitrary code embedded in the
    # file; only load .npy files produced by this pipeline.
    path = osp.join(output_dir, 'partfield', 'part_feat_struct_mesh_zup_batch_part_plane.npy')
    struct_part_planes = torch.from_numpy(np.load(path, allow_pickle=True)).cuda()

    path = osp.join(output_dir, 'partfield', 'part_feat_app_mesh_zup_batch_part_plane.npy')
    app_part_planes = torch.from_numpy(np.load(path, allow_pickle=True)).cuda()

    app_labels, struct_labels, point_feat1, point_feat2 = partfield.cosegment_part(app_coords, app_part_planes, struct_coords, struct_part_planes, cfg.app_guidance.num_part_clusters)

    app_labels = torch.from_numpy(app_labels.flatten()).cuda()
    struct_labels = torch.from_numpy(struct_labels.flatten()).cuda()

    # point_feat1 is indexed with appearance masks, point_feat2 with structure masks.
    point_feat1 = torch.from_numpy(point_feat1).cuda()
    point_feat2 = torch.from_numpy(point_feat2).cuda()

    # The matching depends only on the constant inputs above, so it is done
    # once here instead of on every sampling step.
    part_targets = _match_part_targets(app_feats, app_labels, point_feat1, struct_labels, point_feat2)
    if not part_targets:
        log.warning("No part label is shared between appearance and structure; running without appearance guidance.")

    # Optimization Starts
    # The noise is drawn on the CPU (so seeded runs see the same values as
    # before) and moved to the GPU up front, keeping the parameter on one
    # device for its whole lifetime.
    struct_feats_params = torch.nn.Parameter(
        torch.randn((struct_coords.shape[0], cfg.flow_model_in_channels)).cuda(),
        requires_grad=True,
    )

    param_list = [struct_feats_params]
    optimizer = torch.optim.AdamW(param_list, lr=cfg.app_guidance.learning_rate)
    # Constant multiplier of 1: the learning rate is not changed by this scheduler.
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda x: 1)

    best_loss = float('inf')
    feats = None

    image = generation_pipeline.preprocess_image(app_image)
    cond = generation_pipeline.get_cond([image])

    flow_model = generation_pipeline.models['slat_flow_model']

    sampler_params = {
        "cfg_strength": cfg.app_guidance.cfg_strength,
        "cfg_interval": cfg.app_guidance.cfg_interval,
    }

    t_pairs = _build_time_pairs(cfg.app_guidance.steps, cfg.app_guidance.rescale_t)
    last_iteration = len(t_pairs) - 1

    std = torch.tensor(generation_pipeline.slat_normalization['std'])[None].cuda()
    mean = torch.tensor(generation_pipeline.slat_normalization['mean'])[None].cuda()

    log.info(f"Beginning guidance + flow sampling loop for {len(t_pairs)} steps...")
    for iteration, (t, t_prev) in enumerate(t_pairs):
        optimizer.zero_grad()

        # Diffusion: the sampler runs without autograd, so it gets a detached
        # copy of the current latents.
        noise = sp.SparseTensor(
            feats=struct_feats_params.detach().clone(),
            coords=struct_coords.int(),
        ).cuda()

        with torch.no_grad():
            out = generation_pipeline.slat_sampler.sample_once(flow_model, noise, t, t_prev, **cond, **sampler_params)

        # Overwrite the parameter's storage with the sampler output while
        # keeping the same Parameter object registered with the optimizer.
        struct_feats_params.data = out.pred_x_prev.feats

        # Optimization: skipped on the final step and when nothing could be
        # matched (the average over zero parts would be NaN).
        if iteration >= last_iteration or not part_targets:
            continue

        # Appearance Loss, accumulated part by part in float32.
        app_loss = torch.zeros((), device=struct_feats_params.device)
        for struct_mask, matched in part_targets:
            app_loss = app_loss + F.mse_loss(struct_feats_params[struct_mask], matched)

        app_loss = cfg.app_guidance.loss_weight * (app_loss / len(part_targets))

        total_loss = app_loss

        total_loss.backward()
        optimizer.step()
        scheduler.step()

        loss_value = total_loss.item()

        if (iteration == 0) or (iteration + 1) % cfg.log_every == 0:
            message = f"Step: {iteration}, Appearance Loss: {loss_value:.4f}, Total Loss: {loss_value:.4f}"
            log.info(message)

        # NOTE(review): the snapshot is taken after the optimizer step, and
        # the output of the final pure-sampling step is never a candidate —
        # confirm that this is intended.
        if loss_value < best_loss:
            best_loss = loss_value
            feats = struct_feats_params.detach() * std + mean

    if feats is None:
        # No guidance step produced a usable loss (e.g. a single sampling step
        # or no shared part labels); decode the final sampled latents rather
        # than passing None to the decoder.
        log.warning("No guided step was recorded; decoding the final sampled latents.")
        feats = struct_feats_params.detach() * std + mean

    # Decode SLAT
    log.info("Decoding output SLAT...")
    out_meshpath = osp.join(output_dir, 'out_app.glb')
    out_gspath = osp.join(output_dir, 'out_gaussian_app.mp4')
    generation.decode_slat(generation_pipeline, feats, struct_coords, out_meshpath, out_gspath)