| import torch |
| import torch.nn as nn |
|
|
| from typing import Optional |
|
|
| from Utils import return_edges, split_state, central_diff, masked_mean |
|
|
# Skeleton edge list built once at import time by Utils.return_edges().
# bone_length_loss iterates it as (i, j) pairs of joint indices.
# NOTE(review): the name suggests the 17-joint Human3.6M skeleton -- confirm
# against Utils.return_edges.
H36M17_EDGES = return_edges()
|
|
|
|
|
def kinematic_consistency_losses(
    s_pred: torch.Tensor,
    s_target: torch.Tensor,
    fps: float,
    mask_btj: Optional[torch.Tensor] = None,
    reduction: str = "mean",
    eps: float = 1e-6,
):
    """Measure how well each predicted derivative matches a finite difference.

    The predicted state is split into four components ``p, v, a, j`` via
    ``split_state`` (presumably position, velocity, acceleration and jerk --
    confirm against ``Utils.split_state``).  For each consecutive pair the
    higher-order component is compared with ``central_diff`` of the
    lower-order one, using ``dt = 1 / fps``.  The squared residual, summed
    over the last dimension, is divided by the squared magnitude of the
    matching *target* component (detached, clamped from below by ``eps``).

    Args:
        s_pred: Predicted state tensor accepted by ``split_state``.
        s_target: Target state tensor; only used for the normalisation.
        fps: Sampling rate; must be strictly positive.
        mask_btj: Optional mask forwarded to ``masked_mean``
            (assumed to match the per-element loss shape -- TODO confirm).
        reduction: ``"mean"`` reduces every term with ``masked_mean``;
            ``"none"`` returns the unreduced per-element terms.
        eps: Lower bound for the normalising denominator.

    Returns:
        Dict with keys ``"pv"``, ``"va"`` and ``"aj"``.

    Raises:
        ValueError: If ``reduction`` is not ``"mean"`` or ``"none"``, or if
            ``fps`` is not strictly positive.
    """
    # Fail fast: reject bad arguments before any tensor work is done.
    if reduction not in ("mean", "none"):
        raise ValueError(f"Unknown reduction: {reduction}")
    if fps <= 0:
        # A zero fps would divide by zero and a negative one would flip the
        # sign of every finite difference.
        raise ValueError(f"fps must be positive, got {fps}")

    dt = 1.0 / fps
    p, v, a, j = split_state(s_pred)
    _, v_t, a_t, j_t = split_state(s_target)

    def _relative_sq_error(derivative, lower_order, reference):
        """Squared residual of derivative vs. finite difference, normalised."""
        residual = derivative - central_diff(lower_order, dt)
        # The target magnitude is detached so it acts as a constant scale.
        scale = (reference.detach() ** 2).sum(dim=-1).clamp_min(eps)
        return (residual ** 2).sum(dim=-1) / scale

    terms = {
        "pv": _relative_sq_error(v, p, v_t),
        "va": _relative_sq_error(a, v, a_t),
        "aj": _relative_sq_error(j, a, j_t),
    }

    if reduction == "none":
        return terms
    return {name: masked_mean(term, mask_btj) for name, term in terms.items()}
|
|
|
|
|
def jerk_reg_loss(
    s_pred: torch.Tensor,
    s_target: torch.Tensor,
    mask_btj: Optional[torch.Tensor] = None,
    fps: float = 30.0,
    eps: float = 1e-6,
):
    """Return the predicted jerk energy relative to the target jerk energy.

    The fourth component returned by ``split_state`` (presumably jerk --
    confirm against ``Utils.split_state``) is squared, summed over the last
    dimension and averaged with ``masked_mean`` for both prediction and
    target.  The target average is detached and clamped from below by
    ``eps`` before being used as the denominator.

    Args:
        s_pred: Predicted state tensor accepted by ``split_state``.
        s_target: Target state tensor accepted by ``split_state``.
        mask_btj: Optional mask forwarded to ``masked_mean``.
        fps: Accepted for interface compatibility; not used here.
        eps: Lower bound for the normalising denominator.

    Returns:
        ``masked_mean(|j_pred|^2) / masked_mean(|j_target|^2)``.
    """

    def _mean_sq_norm(component):
        """Masked mean of the squared norm over the last dimension."""
        return masked_mean((component ** 2).sum(dim=-1), mask_btj)

    _, _, _, jerk_hat = split_state(s_pred)
    _, _, _, jerk_ref = split_state(s_target)

    # The reference energy is a constant scale: no gradient flows through it.
    reference_energy = _mean_sq_norm(jerk_ref).detach().clamp_min(eps)
    predicted_energy = _mean_sq_norm(jerk_hat)

    return predicted_energy / reference_energy
|
|
|
|
|
def state_reconstruction_loss(
    s_pred: torch.Tensor,
    s_target: torch.Tensor,
    mask_btj: Optional[torch.Tensor] = None,
    w_p: float = 1.0,
    w_v: float = 1.0,
    w_a: float = 0.5,
    w_j: float = 0.25,
):
    """Weighted squared-error reconstruction loss over the four state parts.

    Both states are split with ``split_state`` into four components
    (presumably position, velocity, acceleration, jerk -- confirm against
    ``Utils.split_state``).  For each component the squared difference is
    summed over the last dimension and reduced with ``masked_mean``.

    Args:
        s_pred: Predicted state tensor accepted by ``split_state``.
        s_target: Target state tensor accepted by ``split_state``.
        mask_btj: Optional mask forwarded to ``masked_mean``.
        w_p: Weight of the first component's loss.
        w_v: Weight of the second component's loss.
        w_a: Weight of the third component's loss.
        w_j: Weight of the fourth component's loss.

    Returns:
        Tuple ``(total, parts)`` where ``total`` is the weighted sum and
        ``parts`` maps ``"rec_p"``, ``"rec_v"``, ``"rec_a"``, ``"rec_j"`` to
        the unweighted per-component losses.
    """
    p_hat, v_hat, a_hat, j_hat = split_state(s_pred)
    p_ref, v_ref, a_ref, j_ref = split_state(s_target)

    def _masked_sq_error(estimate, reference):
        """Masked mean of the squared error summed over the last dimension."""
        return masked_mean(((estimate - reference) ** 2).sum(dim=-1), mask_btj)

    parts = {
        "rec_p": _masked_sq_error(p_hat, p_ref),
        "rec_v": _masked_sq_error(v_hat, v_ref),
        "rec_a": _masked_sq_error(a_hat, a_ref),
        "rec_j": _masked_sq_error(j_hat, j_ref),
    }

    total = (
        w_p * parts["rec_p"]
        + w_v * parts["rec_v"]
        + w_a * parts["rec_a"]
        + w_j * parts["rec_j"]
    )

    return total, parts
|
|
|
|
|
def bone_length_loss(
    s_pred: torch.Tensor,
    s_target: torch.Tensor,
    mask_btj: Optional[torch.Tensor] = None,
    edges=H36M17_EDGES,
):
    """Squared difference between predicted and target bone lengths.

    Only the first component returned by ``split_state`` is used (presumably
    joint positions indexed as ``[batch, time, joint, coord]`` -- TODO
    confirm).  For every ``(i, j)`` pair in ``edges`` the Euclidean distance
    between joints ``i`` and ``j`` is computed for prediction and target;
    the squared length differences are reduced with ``masked_mean``.

    Args:
        s_pred: Predicted state tensor accepted by ``split_state``.
        s_target: Target state tensor accepted by ``split_state``.
        mask_btj: Optional per-joint mask.  When given, an edge's mask is the
            product of the masks of its two joints.
        edges: Iterable of ``(i, j)`` joint-index pairs.

    Returns:
        ``masked_mean`` of the per-edge squared length error.
    """
    joints_hat, _, _, _ = split_state(s_pred)
    joints_ref, _, _, _ = split_state(s_target)

    def _edge_length(joints, start, end):
        """Distance between two joints; 1e-8 keeps sqrt away from zero."""
        offset = joints[:, :, start] - joints[:, :, end]
        return torch.sqrt((offset ** 2).sum(dim=-1) + 1e-8)

    # One pass over the edges yields both the predicted and target lengths.
    length_pairs = [
        (_edge_length(joints_hat, start, end), _edge_length(joints_ref, start, end))
        for start, end in edges
    ]
    lengths_hat = torch.stack([pair[0] for pair in length_pairs], dim=-1)
    lengths_ref = torch.stack([pair[1] for pair in length_pairs], dim=-1)

    sq_length_error = (lengths_hat - lengths_ref) ** 2

    edge_mask = None
    if mask_btj is not None:
        # An edge counts only where both of its end joints are unmasked.
        edge_mask = torch.stack(
            [mask_btj[:, :, start] * mask_btj[:, :, end] for start, end in edges],
            dim=-1,
        )

    return masked_mean(sq_length_error, edge_mask)
|
|
|
|
|
class PINNPretrainLoss(nn.Module):
    """Step 1 only: physics-aware encoder pretraining loss.

    Combines the state reconstruction loss, the three kinematic consistency
    terms, a jerk regulariser and a bone-length loss into one weighted total.

    Args:
        fps: Sampling rate forwarded to the kinematic and jerk terms.
        w_state_rec: Weight of the (already weighted) reconstruction loss.
        w_pv: Weight of the ``"pv"`` kinematic consistency term.
        w_va: Weight of the ``"va"`` kinematic consistency term.
        w_aj: Weight of the ``"aj"`` kinematic consistency term.
        w_jerk: Weight of the jerk regulariser; the term is only evaluated
            when this is strictly positive.
        w_bone: Weight of the bone-length loss.
        w_rec_p: Weight passed as ``w_p`` to ``state_reconstruction_loss``.
        w_rec_v: Weight passed as ``w_v`` to ``state_reconstruction_loss``.
        w_rec_a: Weight passed as ``w_a`` to ``state_reconstruction_loss``.
        w_rec_j: Weight passed as ``w_j`` to ``state_reconstruction_loss``.
    """

    def __init__(
        self,
        fps: float = 30.0,
        w_state_rec: float = 1.0,
        w_pv: float = 0.0,
        w_va: float = 0.0,
        w_aj: float = 0.0,
        w_jerk: float = 0.001,
        w_bone: float = 0.1,
        w_rec_p: float = 2.0,
        w_rec_v: float = 0.5,
        w_rec_a: float = 0.25,
        w_rec_j: float = 0.1,
    ):
        super().__init__()
        self.fps = fps
        self.w_state_rec = w_state_rec
        self.w_pv = w_pv
        self.w_va = w_va
        self.w_aj = w_aj
        self.w_jerk = w_jerk
        self.w_bone = w_bone
        # Per-component reconstruction weights; the defaults are the values
        # that used to be hard-coded in forward().
        self.w_rec_p = w_rec_p
        self.w_rec_v = w_rec_v
        self.w_rec_a = w_rec_a
        self.w_rec_j = w_rec_j

    def forward(
        self,
        s_pred: torch.Tensor,
        s_target: torch.Tensor,
        mask_btj: Optional[torch.Tensor] = None,
    ):
        """Evaluate every loss term and the weighted total.

        Args:
            s_pred: Predicted state tensor.
            s_target: Target state tensor.
            mask_btj: Optional mask forwarded to every term.

        Returns:
            Dict with ``loss_state_rec``, ``loss_pv``, ``loss_va``,
            ``loss_aj``, ``loss_bone``, ``loss_jerk_reg``, ``loss_total`` and
            the unweighted reconstruction parts ``loss_rec_p``,
            ``loss_rec_v``, ``loss_rec_a``, ``loss_rec_j``.
        """
        loss_dict = {}

        # The kinematic terms are always evaluated so they can be reported,
        # even when their weights are zero.
        kin = kinematic_consistency_losses(
            s_pred=s_pred,
            s_target=s_target,
            fps=self.fps,
            mask_btj=mask_btj,
            reduction="mean",
        )
        loss_pv = kin["pv"]
        loss_va = kin["va"]
        loss_aj = kin["aj"]

        loss_rec, rec_dict = state_reconstruction_loss(
            s_pred=s_pred,
            s_target=s_target,
            mask_btj=mask_btj,
            w_p=self.w_rec_p,
            w_v=self.w_rec_v,
            w_a=self.w_rec_a,
            w_j=self.w_rec_j,
        )

        loss_bone = bone_length_loss(
            s_pred=s_pred,
            s_target=s_target,
            mask_btj=mask_btj,
        )

        if self.w_jerk > 0.0:
            loss_jerk = jerk_reg_loss(
                s_pred=s_pred,
                s_target=s_target,
                mask_btj=mask_btj,
                fps=self.fps,
            )
        else:
            # Scalar zero with the dtype and device of s_pred keeps the sum
            # below well-typed when the regulariser is disabled.
            loss_jerk = s_pred.new_zeros(())

        total = (
            self.w_state_rec * loss_rec +
            self.w_pv * loss_pv +
            self.w_va * loss_va +
            self.w_aj * loss_aj +
            self.w_jerk * loss_jerk +
            self.w_bone * loss_bone
        )

        loss_dict["loss_state_rec"] = loss_rec
        loss_dict["loss_pv"] = loss_pv
        loss_dict["loss_va"] = loss_va
        loss_dict["loss_aj"] = loss_aj
        loss_dict["loss_bone"] = loss_bone
        loss_dict["loss_jerk_reg"] = loss_jerk
        loss_dict["loss_total"] = total

        # Expose the unweighted per-component reconstruction losses too.
        loss_dict.update({f"loss_{k}": v for k, v in rec_dict.items()})

        return loss_dict
|
|