# GuideFlow3D: lib/opt/appearance.py
# Committed by suvadityamuk in 382733a ("feat: add initial files for space").
import os.path as osp
import numpy as np
import torch
import torch.nn.functional as F
import utils3d
from PIL import Image
import logging
import third_party.TRELLIS.trellis.modules.sparse as sp
from third_party.TRELLIS.trellis.pipelines import TrellisImageTo3DPipeline
from lib.util import partfield, generation
# Module-level logger, named after this module via ``__name__``.
log = logging.getLogger(name=__name__)
def _timestep_pairs(steps, rescale_t):
    """Return the ``(t, t_prev)`` pairs of a rescaled flow schedule running from 1 to 0.

    ``steps + 1`` evenly spaced, descending times in [0, 1] are warped with
    ``rescale_t * t / (1 + (rescale_t - 1) * t)``. The warp keeps the endpoints
    at 1 and 0, and ``rescale_t == 1`` leaves the schedule linear. Consecutive
    times are then paired up.

    Args:
        steps: Number of sampler steps, i.e. the number of returned pairs.
        rescale_t: Time-warping factor.

    Returns:
        A list of ``steps`` tuples ``(t, t_prev)`` of NumPy float scalars.
    """
    t_seq = np.linspace(1, 0, steps + 1)
    t_seq = rescale_t * t_seq / (1 + (rescale_t - 1) * t_seq)
    return list(zip(t_seq[:-1], t_seq[1:]))


def _match_appearance_targets(struct_labels, app_labels, struct_point_feats, app_point_feats, app_feats):
    """Pick, part by part, the appearance latent each structure voxel should imitate.

    For every label in ``app_labels`` that also occurs in ``struct_labels``,
    each structure voxel with that label is matched to the appearance voxel
    with the same label whose part-field feature minimises ``(1 - dot) / 2``.
    The result depends only on labels and part-field features, not on the
    latents being optimized, so it is computed once before sampling.

    Args:
        struct_labels: 1-D part label per structure voxel.
        app_labels: 1-D part label per appearance voxel.
        struct_point_feats: Part-field features, row-aligned with ``struct_labels``.
        app_point_feats: Part-field features, row-aligned with ``app_labels``.
        app_feats: Appearance latents, row-aligned with ``app_labels``.

    Returns:
        A list of ``(struct_mask, matched_app_feats)`` tuples, one per shared
        label. Row ``i`` of ``matched_app_feats`` is the target for the ``i``-th
        structure voxel selected by ``struct_mask``. The list is empty if no
        label is shared.
    """
    targets = []
    for label in torch.unique(app_labels):
        app_mask = app_labels == label
        struct_mask = struct_labels == label
        if not app_mask.any() or not struct_mask.any():
            continue
        # NOTE(review): treated as cosine similarity, which assumes the
        # part-field features are L2-normalised; confirm in lib.util.partfield.
        cos_sim = torch.matmul(struct_point_feats[struct_mask], app_point_feats[app_mask].T)
        cos_dist = (1 - cos_sim) / 2.
        nearest = torch.argmin(cos_dist, dim=1)
        targets.append((struct_mask, app_feats[app_mask][nearest]))
    return targets


def optimize_appearance(cfg, output_dir):
    """Sample a structured latent (SLAT) for the structure voxels with appearance guidance.

    Runs the TRELLIS image-conditioned SLAT flow sampler over the voxels in
    ``<output_dir>/voxels/struct_voxels.ply``, conditioned on
    ``<output_dir>/app_image.png``. After every flow step except the last, the
    current latents take one AdamW step on an MSE loss. The loss pulls each
    structure voxel toward the appearance latent
    (``latents/<cfg.latent_name>/appearance.npz``) it was matched to within
    its co-segmented part. The latents captured at the lowest-loss step are
    de-normalised with the pipeline's SLAT statistics. They are then passed
    to ``generation.decode_slat`` with the output paths
    ``<output_dir>/out_app.glb`` and ``<output_dir>/out_gaussian_app.mp4``.

    Args:
        cfg: Config object. Reads ``trellis_img_model_name``, ``latent_name``,
            ``flow_model_in_channels``, ``log_every`` and
            ``app_guidance.{num_part_clusters, learning_rate, cfg_strength,
            cfg_interval, steps, rescale_t, loss_weight}``.
        output_dir: Directory holding the inputs above plus
            ``partfield/part_feat_{struct,app}_mesh_zup_batch_part_plane.npy``.
            The output paths are built inside it too.

    Requires a CUDA device, because every tensor is moved with ``.cuda()``.
    If no guided step yields a finite loss (e.g. ``steps <= 1`` or no shared
    part label), the final flow sample is decoded instead.
    """
    log.info("Starting appearance optimization...")
    generation_pipeline = TrellisImageTo3DPipeline.from_pretrained(cfg.trellis_img_model_name)
    generation_pipeline.cuda()

    # Appearance latents. The context manager closes the .npz archive.
    latent_path = osp.join(output_dir, 'latents', cfg.latent_name, "appearance.npz")
    with np.load(latent_path) as data:
        app_feats = torch.from_numpy(data['feats']).cuda()
        app_coords = torch.from_numpy(data['coords']).cuda()

    # Structure voxels: map positions to integer indices on a 64^3 grid.
    # Assumes the PLY positions lie in [-0.5, 0.5]; TODO confirm.
    ply_path = osp.join(output_dir, 'voxels', 'struct_voxels.ply')
    struct_xyz = torch.from_numpy(utils3d.io.read_ply(ply_path)[0]).float().cuda()
    struct_xyz = ((struct_xyz + 0.5) * 64).long()
    # Prepend a zero column, presumably the batch index that
    # sp.SparseTensor coords expect; confirm.
    batch_idx = torch.zeros((struct_xyz.size(0), 1), dtype=struct_xyz.dtype, device=struct_xyz.device)
    struct_coords = torch.cat([batch_idx, struct_xyz], dim=1)

    # convert() loads the pixels into a new image, so closing the file is safe.
    with Image.open(osp.join(output_dir, 'app_image.png')) as img:
        app_image = img.convert('RGB')

    # Part-field planes for both shapes, co-segmented into shared part clusters.
    # NOTE: allow_pickle=True lets np.load unpickle arbitrary objects, so only
    # use it on trusted files.
    partfield_dir = osp.join(output_dir, 'partfield')
    struct_part_planes = torch.from_numpy(
        np.load(osp.join(partfield_dir, 'part_feat_struct_mesh_zup_batch_part_plane.npy'), allow_pickle=True)
    ).cuda()
    app_part_planes = torch.from_numpy(
        np.load(osp.join(partfield_dir, 'part_feat_app_mesh_zup_batch_part_plane.npy'), allow_pickle=True)
    ).cuda()
    app_labels, struct_labels, app_point_feats, struct_point_feats = partfield.cosegment_part(
        app_coords, app_part_planes, struct_coords, struct_part_planes, cfg.app_guidance.num_part_clusters
    )
    app_labels = torch.from_numpy(app_labels.flatten()).cuda()
    struct_labels = torch.from_numpy(struct_labels.flatten()).cuda()
    app_point_feats = torch.from_numpy(app_point_feats).cuda()
    struct_point_feats = torch.from_numpy(struct_point_feats).cuda()

    targets = _match_appearance_targets(struct_labels, app_labels, struct_point_feats, app_point_feats, app_feats)
    if not targets:
        log.warning("No part label is shared by the appearance and structure co-segmentation; "
                    "sampling without appearance guidance.")

    # The latents start as Gaussian noise. The sampler output assigned to .data
    # in the loop replaces their storage at every step. Keep this randn call
    # before preprocess_image/get_cond so the RNG stream matches earlier runs.
    struct_feats_params = torch.nn.Parameter(
        torch.randn((struct_coords.shape[0], cfg.flow_model_in_channels)), requires_grad=True
    )
    optimizer = torch.optim.AdamW([struct_feats_params], lr=cfg.app_guidance.learning_rate)
    # A constant multiplier: the learning rate stays at its initial value.
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda _: 1)

    cond = generation_pipeline.get_cond([generation_pipeline.preprocess_image(app_image)])
    flow_model = generation_pipeline.models['slat_flow_model']
    sampler_params = {
        "cfg_strength": cfg.app_guidance.cfg_strength,
        "cfg_interval": cfg.app_guidance.cfg_interval,
    }
    t_pairs = _timestep_pairs(cfg.app_guidance.steps, cfg.app_guidance.rescale_t)
    std = torch.tensor(generation_pipeline.slat_normalization['std'])[None].cuda()
    mean = torch.tensor(generation_pipeline.slat_normalization['mean'])[None].cuda()
    sparse_coords = struct_coords.int()  # loop-invariant

    best_loss = float('inf')
    feats = None
    last_iteration = len(t_pairs) - 1
    log.info(f"Beginning guidance + flow sampling loop for {len(t_pairs)} steps...")
    for iteration, (t, t_prev) in enumerate(t_pairs):
        optimizer.zero_grad()

        # Flow step t -> t_prev on a copy of the current latents, without autograd.
        x_t = sp.SparseTensor(
            feats=struct_feats_params.clone().cuda(),
            coords=sparse_coords,
        ).cuda()
        with torch.no_grad():
            out = generation_pipeline.slat_sampler.sample_once(flow_model, x_t, t, t_prev, **cond, **sampler_params)
        struct_feats_params.data = out.pred_x_prev.feats

        # The final step is neither guided nor scored.
        if iteration == last_iteration or not targets:
            continue

        # Appearance guidance: mean per-part MSE to the matched appearance latents.
        losses = [F.mse_loss(struct_feats_params[mask], target) for mask, target in targets]
        total_loss = cfg.app_guidance.loss_weight * (sum(losses) / len(losses))
        total_loss.backward()
        optimizer.step()
        scheduler.step()

        loss_value = total_loss.item()
        if (iteration == 0) or (iteration + 1) % cfg.log_every == 0:
            log.info(f"Step: {iteration}, Appearance Loss: {loss_value:.4f}, Total Loss: {loss_value:.4f}")
        # NOTE(review): the loss is measured before optimizer.step(), but the
        # snapshot is taken after it. This matches the previous behaviour.
        if loss_value < best_loss:
            best_loss = loss_value
            feats = struct_feats_params.detach() * std + mean

    if feats is None:
        log.warning("No guided step produced a finite appearance loss; decoding the final flow sample instead.")
        feats = struct_feats_params.detach().to(std.device) * std + mean

    # Decode SLAT
    log.info("Decoding output SLAT...")
    out_meshpath = osp.join(output_dir, 'out_app.glb')
    out_gspath = osp.join(output_dir, 'out_gaussian_app.mp4')
    generation.decode_slat(generation_pipeline, feats, struct_coords, out_meshpath, out_gspath)