# pixel3dmms/src/pixel3dmm/lightning/p3dmm_system.py
# alexnasa's picture
# Upload 66 files
# cf92dec verified
from PIL import Image, ImageDraw
import os
import torch
import numpy as np
import pytorch_lightning as L
import torch.nn as nn
from pixel3dmm.lightning.utils import CosineWarmupScheduler, WarmupScheduler
from pixel3dmm.lightning.p3dmm_network import Network
from pixel3dmm import env_paths
def fov_to_ixt(fov, reso=512):
    """Build a batch of 3x3 pinhole intrinsics from field-of-view angles.

    Parameters
    ----------
    fov: torch.tensor N
        field of view per sample; it is passed to ``torch.tan`` directly,
        i.e. it is interpreted in radians
    reso: number
        image resolution; the principal point is placed at ``reso / 2``
        on both axes

    Returns
    -------
    ixt: torch.tensor Nx3x3
        float32 intrinsics on the same device as ``fov``
    """
    num_cameras = fov.shape[0]
    center = reso / 2
    focal_length = .5 * reso / torch.tan(.5 * fov)

    # Start from identity so the bottom-right entry is already 1.
    intrinsics = torch.eye(3, dtype=torch.float32, device=fov.device).repeat(num_cameras, 1, 1)
    for axis in (0, 1):
        intrinsics[:, axis, axis] = focal_length
        intrinsics[:, axis, 2] = center
    return intrinsics
def batch_rodrigues(
    rot_vecs: torch.Tensor,
    epsilon: float = 1e-8,
) -> torch.Tensor:
    ''' Calculates the rotation matrices for a batch of rotation vectors

    Parameters
    ----------
    rot_vecs: torch.tensor Nx3
        array of N axis-angle vectors
    epsilon: float
        offset added to every component of ``rot_vecs`` before the norm is
        taken, so that an all-zero vector does not cause a division by zero

    Returns
    -------
    R: torch.tensor Nx3x3
        The rotation matrices for the given axis-angle parameters
    '''
    batch_size = rot_vecs.shape[0]
    device, dtype = rot_vecs.device, rot_vecs.dtype

    # Use the caller-supplied epsilon (previously a hard-coded 1e-8 was used
    # and the parameter was silently ignored).
    angle = torch.norm(rot_vecs + epsilon, dim=1, keepdim=True)
    rot_dir = rot_vecs / angle

    # Nx1x1 so they broadcast against the Nx3x3 matrices below.
    cos = torch.unsqueeze(torch.cos(angle), dim=1)
    sin = torch.unsqueeze(torch.sin(angle), dim=1)

    # Nx1 components of the rotation axis.
    rx, ry, rz = torch.split(rot_dir, 1, dim=1)
    zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device)

    # Skew-symmetric cross-product matrix of the axis, row-major.
    K = torch.cat(
        [zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1,
    ).view((batch_size, 3, 3))

    ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0)
    rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K)
    return rot_mat
def pad_to_3_channels(img):
    """Zero-pad the last axis of ``img`` so that it has exactly 3 channels.

    Arrays that already have 3 channels are returned unchanged (same
    object). Arrays with 1 or 2 channels get zero-filled channels appended.
    Any other channel count raises ``ValueError``.
    """
    n_channels = img.shape[-1]
    if n_channels == 3:
        return img
    if n_channels not in (1, 2):
        raise ValueError('too many dimensions in prediction type!')

    # One all-zero channel with the dtype and leading shape of the input.
    blank = np.zeros_like(img[..., :1])
    return np.concatenate([img] + [blank] * (3 - n_channels), axis=-1)
class system(L.LightningModule):
    """Lightning module wrapping ``Network`` for per-pixel map prediction.

    The network is called as ``output, conf = self.net(batch)`` where
    ``output`` is a dict keyed by prediction type and ``conf`` is either
    ``None`` or a confidence tensor. Losses are masked L1 differences between
    ``output`` and the ground-truth maps stored in ``batch``.
    """

    # Gradient norms above this value (after the warm-up steps below) cause the
    # gradients to be discarded in ``optimizer_step``.
    GRAD_NORM_SKIP_THRESHOLD = 10000
    GRAD_NORM_WARMUP_STEPS = 20

    def __init__(self, cfg):
        """Create the network, load the vertex weight mask and set loss weights."""
        super().__init__()
        self.glctx = None
        self.cfg = cfg
        self.net = Network(cfg)

        vertex_weight_mask = np.load(f'{env_paths.VERTEX_WEIGHT_MASK}')
        self.register_buffer('vertex_weight_mask', torch.from_numpy(vertex_weight_mask).float())

        # Per-batch validation results, reduced in ``on_validation_epoch_end``.
        self.validation_step_outputs = []
        self.validation_step_outputs_per_dataset = []
        self.dataset_types = [
            'facescape',
            'nphm',
            'ava',
        ]

        # When True, the next validation step on rank 0 writes visualisations.
        self.do_eval = True
        # Weight of the log-confidence term in the confidence-weighted loss.
        self.alpha = 1.0
        self.save_hyperparameters()

        self.loss_weights = {
            'albedo': 1.0,  # 1.0/0.13,
            'depth': 1.0,
            'pos_map': 1.0,  # 1.0/0.0006,
            'pos_map_can': 1.0,  # 1.0/0.0006,
            'normals': 0.1,  # TODO: caution, was 1.0/0.03
            'normals_can': 1.0,  # 1.0/0.03,
            'uv_map': 10.0,  # 1.0/0.001,
            'nocs': 1.0,  # 1.0/0.0006,
        }

    # ------------------------------------------------------------------ losses

    def _masked_l1(self, mask, gt, pred, c_map=None, weight=None):
        """Return the mean absolute masked residual between ``gt`` and ``pred``.

        With ``c_map`` given, the residual becomes
        ``c_map * (gt - pred) - self.alpha * log(c_map)`` before masking.
        ``weight`` is an optional extra per-pixel factor.
        """
        residual = gt - pred
        if c_map is not None:
            residual = c_map * residual - self.alpha * torch.log(c_map)
        factor = mask if weight is None else weight * mask
        return (factor * residual).abs().mean()

    def _normal_losses(self, batch, output, c_map=None):
        """Compute the ``normals`` (and ``normals_can``) losses, if configured."""
        losses = {}
        if 'normals' not in self.cfg.model.prediction_type:
            return losses
        mask = batch['tar_msk'].unsqueeze(2)
        names = ['normals']
        if self.cfg.model.pred_disentangled:
            names.append('normals_can')
        for name in names:
            # Ground truth is channel-last; move channels in front of H, W.
            gt = batch[name].permute(0, 1, 4, 2, 3)
            losses[name] = self._masked_l1(mask, gt, output[name], c_map=c_map)
        return losses

    def _position_losses(self, batch, output, c_map=None):
        """Compute the ``uv_map`` / ``depth`` / ``nocs`` losses, if configured."""
        losses = {}
        mask = batch['tar_msk'].unsqueeze(2)
        for name in ['uv_map', 'depth', 'nocs']:
            if name not in self.cfg.model.prediction_type:
                continue
            weight = None
            if name in ('uv_map', 'nocs'):  # ATTENTION: only for nocs?
                # Pixels inside ``uv_masks`` get weight 1.2, all others 0.2.
                weight = batch['uv_masks'].unsqueeze(2).float() + 0.2
            gt = batch[name].permute(0, 1, 4, 2, 3)
            losses[name] = self._masked_l1(mask, gt, output[name], c_map=c_map, weight=weight)
        return losses

    # ---------------------------------------------------------------- training

    def training_step(self, batch, batch_idx):
        """Run one forward pass, log the losses and return the weighted total."""
        output, conf = self.net(batch)

        # NOTE(review): ``c_map`` used to be hard-coded to None, which made the
        # confidence-weighted branch raise a TypeError whenever the network
        # returned a confidence map. Using ``conf`` directly assumes it is
        # strictly positive and broadcastable to the outputs - confirm.
        c_map = conf

        losses = {}
        losses.update(self._normal_losses(batch, output, c_map=c_map))
        losses.update(self._position_losses(batch, output, c_map=c_map))
        if not losses:
            raise ValueError('no training loss could be computed for the configured prediction types')

        loss = 0
        for name, value in losses.items():
            # Loss types without an explicit weight contribute unweighted.
            loss = loss + self.loss_weights.get(name, 1.0) * value

        self.log(f'train/loss', loss.item(), prog_bar=False)
        for name, value in losses.items():
            if name in self.cfg.model.prediction_type:
                self.log(f'train/loss_{name}', value)
        if self.cfg.model.pred_disentangled:
            for name, value in losses.items():
                # Strip the trailing '_can' to recover the base prediction type.
                if name[:-4] in self.cfg.model.prediction_type:
                    self.log(f'train/loss_{name}', value)

        self.log('lr', self.trainer.optimizers[0].param_groups[0]['lr'])

        # Visualise more often when the '/mnt/rohan' path exists.
        vis_interval = 300 if os.path.exists('/mnt/rohan') else 3000
        do_vis = 0 == self.trainer.global_step % vis_interval
        if do_vis and (self.trainer.local_rank == 0):
            # Reuse the outputs of the forward pass above instead of running
            # the network a second time just for visualisation.
            vis_conf = conf.detach() if conf is not None else None
            self.vis_results({k: v.detach() for (k, v) in output.items()}, vis_conf, batch, prex='train')
            # Re-arm the one-shot visualisation in ``validation_step``.
            self.do_eval = True
            torch.cuda.empty_cache()

        return loss

    def optimizer_step(
            self,
            *args, **kwargs
    ):
        """
        Skipping updates in case of unstable gradients
        https://github.com/Lightning-AI/lightning/issues/4956

        The global gradient norm is logged as ``grad/norm``. If it is
        non-finite, or exceeds ``GRAD_NORM_SKIP_THRESHOLD`` after
        ``GRAD_NORM_WARMUP_STEPS`` global steps, all gradients are cleared
        before delegating to the default ``optimizer_step``.

        NOTE(review): with automatic optimization the backward pass may run
        inside the closure executed by the parent ``optimizer_step``; in that
        case the gradients inspected here belong to an earlier step - confirm.
        """
        grads = [
            param.grad.detach().flatten()
            for param in self.parameters()
            if param.grad is not None
        ]
        if grads:
            norm = torch.cat(grads).norm()
            self.log(f'grad/norm', norm.item(), prog_bar=False)  # , sync_dist=True)

            too_large = bool(norm > self.GRAD_NORM_SKIP_THRESHOLD) and self.global_step > self.GRAD_NORM_WARMUP_STEPS
            # ``isfinite`` catches inf as well as NaN, matching the message below.
            non_finite = not bool(torch.isfinite(norm))
            if too_large or non_finite:
                print(
                    f'detected inf or nan values in gradients. not updating model parameters, OTHER FUNCTION threshold: {self.GRAD_NORM_SKIP_THRESHOLD}, value: {norm.item()}')
                self.zero_grad()
                for param in self.parameters():
                    param.grad = None

        L.LightningModule.optimizer_step(self, *args, **kwargs)

    # -------------------------------------------------------------- validation

    def validation_step(self, batch, batch_idx):
        """Compute unweighted-by-confidence validation losses for one batch.

        The per-type losses and their weighted total are appended to
        ``self.validation_step_outputs`` as plain floats; the total is also
        returned.
        """
        self.net.eval()
        output, conf = self.net(batch)

        # Validation always uses the plain masked L1 (no confidence weighting).
        val_losses = {}
        val_losses.update(self._position_losses(batch, output))
        val_losses.update(self._normal_losses(batch, output))
        loss_dict = {f'loss/{name}': value.item() for name, value in val_losses.items()}

        val_loss = 0
        for prediction_type in self.cfg.model.prediction_type:
            # Prediction types without a validation loss (e.g. 'albedo') are
            # skipped instead of raising a KeyError.
            if prediction_type in val_losses:
                val_loss = val_loss + self.loss_weights[prediction_type] * val_losses[prediction_type]
        loss_dict['loss/total'] = float(val_loss)
        self.validation_step_outputs.append(loss_dict)

        if self.do_eval and self.trainer.local_rank == 0:
            # Reuse the outputs computed above rather than a second forward pass.
            vis_conf = conf.detach() if conf is not None else None
            self.vis_results({k: v.detach() for (k, v) in output.items()}, vis_conf, batch, prex='val')
            self.do_eval = False
            torch.cuda.empty_cache()

        return val_loss

    def on_validation_epoch_end(self):
        """Average the collected validation losses, log them and reset state."""
        outputs = self.validation_step_outputs
        if not outputs:
            # Nothing was validated; averaging an empty list would raise.
            return

        metric_mean = np.mean([x['loss/total'] for x in outputs])
        self.log(f'val/loss', metric_mean, prog_bar=False, sync_dist=True)

        if self.net.n_facial_components == 0:
            for prediction_type in self.cfg.model.prediction_type:
                key = f'loss/{prediction_type}'
                values = [x[key] for x in outputs if key in x]
                if not values:
                    # No validation loss exists for this prediction type.
                    continue
                self.log(f'val/loss_{prediction_type}', np.mean(values), sync_dist=True)

        # NOTE(review): nothing in this class fills
        # ``validation_step_outputs_per_dataset``; this loop only logs when
        # something else populates it - confirm it is still needed.
        for dataset_type in self.dataset_types:
            for loss_type in outputs[0].keys():
                content = [
                    np.array(x[dataset_type][loss_type])
                    for x in self.validation_step_outputs_per_dataset
                    if loss_type in x[dataset_type]
                ]
                if len(content) > 0:
                    metric_mean = np.nanmean(np.stack(content))
                    self.log(f'val_{dataset_type}/{loss_type}', metric_mean, sync_dist=True)

        self.validation_step_outputs.clear()  # free memory
        self.validation_step_outputs_per_dataset.clear()
        torch.cuda.empty_cache()

    # ----------------------------------------------------------- visualisation

    def vis_results(self, output, conf, batch, prex):
        """Write one comparison PNG per batch element.

        Images go to ``{cfg.reconstruction_folder}/{prex}_{global_step}/`` and
        are named after ``batch['meta']['scene']``. Nothing is drawn unless
        ``self.net.n_facial_components == 0``.
        """
        out_folder = f'{self.cfg.reconstruction_folder}/{prex}_{self.trainer.global_step}/'
        os.makedirs(out_folder, exist_ok=True)
        output = {k: v.cpu() for k, v in output.items()}

        if self.net.n_facial_components == 0:
            with torch.no_grad():
                self._save_comparison_images(output, conf, batch, out_folder)
        torch.cuda.empty_cache()

    @staticmethod
    def _rescale_to_unit(first, second=None):
        """Shift and scale tensors jointly so their values span [0, 1].

        A constant input (zero range) is only shifted, avoiding a division by
        zero.
        """
        tensors = [first] if second is None else [first, second]
        lowest = min(t.min().item() for t in tensors)
        shifted = [t - lowest for t in tensors]
        highest = max(t.max().item() for t in shifted)
        scale = highest if highest > 0 else 1.0
        scaled = [t / scale for t in shifted]
        return scaled[0] if second is None else tuple(scaled)

    def _save_comparison_images(self, output, conf, batch, out_folder):
        """Tile RGB | ground truth | prediction (| confidence) and save as PNG.

        Each row is one view; the rows of all drawable prediction types are
        stacked vertically into a single image per batch element.
        """
        configured = self.cfg.model.prediction_type
        tar_msk = batch['tar_msk']

        # Masked ground-truth maps, moved to channel-first layout on the CPU.
        gt_rows = {}
        for name in ['normals', 'albedo', 'uv_map', 'nocs']:
            if name in configured:
                gt_rows[name] = (tar_msk[..., None].float() * batch[name]).permute(0, 1, 4, 2, 3).detach().cpu()
        if 'normals' in configured and self.cfg.model.pred_disentangled:
            gt_rows['normals_can'] = (tar_msk[..., None].float() * batch['normals_can']).permute(
                0, 1, 4, 2, 3).detach().cpu()
        gt_rgb = batch['tar_rgb'].permute(0, 1, 4, 2, 3).detach().cpu()

        prediction_types = list(configured)  # ['pos_map', 'normals', 'albedo', 'uv_map']
        if self.cfg.model.pred_disentangled:
            if 'pos_map' in prediction_types:
                prediction_types.append('pos_map_can')
            if 'normals' in prediction_types:
                prediction_types.append('normals_can')
            if 'uv_map' in prediction_types:
                prediction_types.append('disps')

        # Only types with both a ground-truth row and a prediction can be
        # drawn; others (e.g. 'depth', 'disps') used to raise a KeyError.
        drawable = [name for name in prediction_types if name in gt_rows and name in output]
        if not drawable:
            return

        for i_batch in range(gt_rgb.shape[0]):
            modalities = []
            for name in drawable:
                rows = []
                for i_view in range(gt_rows[name].shape[1]):
                    # Normalise GT and prediction with a shared range so their
                    # colours are directly comparable.
                    gt_map, pred_map = self._rescale_to_unit(
                        gt_rows[name][i_batch, i_view],
                        output[name][i_batch, i_view],
                    )
                    tiles = [
                        gt_rgb[i_batch, i_view].permute(1, 2, 0).detach().cpu().numpy(),
                        pad_to_3_channels(
                            (tar_msk[i_batch, i_view].cpu() * gt_map).permute(1, 2, 0).detach().cpu().numpy()),
                        pad_to_3_channels(pred_map.permute(1, 2, 0).detach().cpu().float().numpy()),
                    ]
                    if conf is not None:
                        conf_map = self._rescale_to_unit(conf[i_batch, i_view].detach())
                        tiles.append(
                            pad_to_3_channels(conf_map.permute(1, 2, 0).detach().cpu().float().numpy()))
                    rows.append((np.concatenate(tiles, axis=1) * 255).astype(np.uint8))
                modalities.append(np.concatenate(rows, axis=0))

            image = Image.fromarray(np.concatenate(modalities, axis=0))
            scene_name = batch['meta']['scene'][i_batch]
            image.save(f'{out_folder}/{scene_name}.png')  # , quality=90)

    # ------------------------------------------------------------ optimisation

    def num_steps(self) -> int:
        """Get number of steps"""
        # Accessing _data_source is flaky and might break
        dataset = self.trainer.fit_loop._data_source.dataloader()
        dataset_size = len(dataset)
        num_devices = max(1, self.trainer.num_devices)
        num_steps = dataset_size * self.trainer.max_epochs * self.cfg.train.limit_train_batches // (
                self.trainer.accumulate_grad_batches * num_devices)
        return int(num_steps)

    def configure_optimizers(self):
        """Build an AdamW optimizer with four parameter groups and a cosine schedule.

        Groups are split by weight decay (biases and LayerNorm parameters get
        none) and by learning rate (modules whose name contains
        ``'img_encoder'`` use ``cfg.train.lr_backbone``). Parameters of the
        modules named ``'flame'`` / ``'flame_generic'`` are never optimised.
        The scheduler is stepped once per optimisation step.
        """
        invalid_params = []
        no_decay_params = []
        backbone_params_no_decay = []
        all_backbone_params = []
        all_non_backbone_params = []

        # add all bias and LayerNorm params to no_decay_params
        for name, module in self.named_modules():
            if name == 'flame' or name == 'flame_generic':
                invalid_params.extend(module.parameters())
                continue

            is_backbone = 'img_encoder' in name
            no_decay_target = backbone_params_no_decay if is_backbone else no_decay_params
            if isinstance(module, nn.LayerNorm):
                no_decay_target.extend(module.parameters())
            elif getattr(module, 'bias', None) is not None:
                no_decay_target.append(module.bias)

            # NOTE(review): ``named_modules()`` also yields the root module
            # (name ''), whose ``parameters()`` covers every parameter. That
            # would put all ids into the non-backbone set and leave
            # ``decay_params_backbone`` below empty - confirm whether the
            # backbone weights are meant to be trained.
            if is_backbone:
                all_backbone_params.extend(module.parameters())
            else:
                all_non_backbone_params.extend(module.parameters())

        # add remaining parameters to decay_params
        _no_decay_ids = set(map(id, no_decay_params))
        _all_backbone_ids = set(map(id, all_backbone_params))
        _all_non_backbone_ids = set(map(id, all_non_backbone_params))
        _backbone_no_decay_ids = set(map(id, backbone_params_no_decay))
        _invalid_ids = set(map(id, invalid_params))

        decay_params = [
            p for p in self.parameters()
            if id(p) not in _no_decay_ids and id(p) not in _all_backbone_ids and id(p) not in _invalid_ids
        ]
        decay_params_backbone = [
            p for p in self.parameters()
            if id(p) not in _backbone_no_decay_ids and id(p) not in _all_non_backbone_ids
            and id(p) not in _invalid_ids
        ]
        no_decay_params = [p for p in no_decay_params if id(p) not in _invalid_ids]
        no_decay_params_backbone = [p for p in backbone_params_no_decay if id(p) not in _invalid_ids]

        # filter out parameters with no grad
        decay_params = [p for p in decay_params if p.requires_grad]
        no_decay_params = [p for p in no_decay_params if p.requires_grad]
        decay_params_backbone = [p for p in decay_params_backbone if p.requires_grad]
        no_decay_params_backbone = [p for p in no_decay_params_backbone if p.requires_grad]

        # Optimizer
        opt_groups = [
            {'params': decay_params, 'weight_decay': self.cfg.train.weight_decay, 'lr': self.cfg.train.lr},
            {'params': decay_params_backbone, 'weight_decay': self.cfg.train.weight_decay,
             'lr': self.cfg.train.lr_backbone},
            {'params': no_decay_params, 'weight_decay': 0.0, 'lr': self.cfg.train.lr},
            {'params': no_decay_params_backbone, 'weight_decay': 0.0, 'lr': self.cfg.train.lr_backbone},
        ]
        optimizer = torch.optim.AdamW(
            opt_groups,
            betas=(self.cfg.train.beta1, self.cfg.train.beta2),
        )

        total_global_batches = self.num_steps()
        scheduler = CosineWarmupScheduler(
            optimizer=optimizer,
            warmup_iters=self.cfg.train.warmup_iters,
            max_iters=total_global_batches,
        )

        return {"optimizer": optimizer,
                "lr_scheduler": {
                    'scheduler': scheduler,
                    'interval': 'step'  # or 'epoch' for epoch-level updates
                }}