from typing import Dict, Optional, Tuple

import torch
import torch.nn as nn

from Utils import return_edges, split_state, central_diff, masked_mean

# Default skeleton edges, iterated as (i, j) joint-index pairs by
# bone_length_loss. Presumably the Human3.6M 17-joint layout -- confirm
# against Utils.return_edges.
H36M17_EDGES = return_edges()


def _relative_sq_err(
    pred_rate: torch.Tensor,
    lower_order: torch.Tensor,
    target_rate: torch.Tensor,
    dt: float,
    eps: float,
) -> torch.Tensor:
    """Squared residual between a predicted rate and a finite difference.

    The residual ``pred_rate - central_diff(lower_order, dt)`` is squared and
    summed over the last axis, then divided by the squared magnitude of
    ``target_rate`` (detached, clamped to at least ``eps``) so the term is
    relative to the scale of the target signal.
    """
    residual = pred_rate - central_diff(lower_order, dt)
    # The target only sets the scale; it must not receive gradients.
    scale = (target_rate.detach() ** 2).sum(dim=-1).clamp_min(eps)
    return (residual ** 2).sum(dim=-1) / scale


def kinematic_consistency_losses(
    s_pred: torch.Tensor,
    s_target: torch.Tensor,
    fps: float,
    mask_btj: Optional[torch.Tensor] = None,
    reduction: str = "mean",
    eps: float = 1e-6,
) -> Dict[str, torch.Tensor]:
    """Measure how well the predicted state components agree with each other.

    ``split_state`` yields four components, treated here as position,
    velocity, acceleration and jerk. Each higher-order component of the
    prediction is compared with the central difference of the component one
    order below it, using a time step of ``1 / fps``.

    Args:
        s_pred: Predicted state, in whatever layout ``split_state`` expects.
        s_target: Target state; only used (detached) to normalise each term.
        fps: Frame rate used to derive the finite-difference time step.
        mask_btj: Optional mask forwarded to ``masked_mean``
            (presumably (B, T, J) -- confirm against the caller).
        reduction: ``"mean"`` to reduce each term with ``masked_mean``,
            ``"none"`` to return the unreduced per-element terms.
        eps: Lower bound on the normalising target magnitude.

    Returns:
        Dict with keys ``"pv"``, ``"va"`` and ``"aj"``.

    Raises:
        ValueError: If ``reduction`` is neither ``"mean"`` nor ``"none"``.
    """
    # Fail fast on a bad argument instead of after all the tensor work.
    if reduction not in ("mean", "none"):
        raise ValueError(f"Unknown reduction: {reduction}")

    dt = 1.0 / fps
    p, v, a, j = split_state(s_pred)
    _, v_t, a_t, j_t = split_state(s_target)

    terms = {
        "pv": _relative_sq_err(v, p, v_t, dt, eps),
        "va": _relative_sq_err(a, v, a_t, dt, eps),
        "aj": _relative_sq_err(j, a, j_t, dt, eps),
    }
    if reduction == "none":
        return terms
    return {name: masked_mean(term, mask_btj) for name, term in terms.items()}


def jerk_reg_loss(
    s_pred: torch.Tensor,
    s_target: torch.Tensor,
    mask_btj: Optional[torch.Tensor] = None,
    fps: float = 30.0,
    eps: float = 1e-6,
) -> torch.Tensor:
    """Return the mean squared predicted jerk relative to the target's.

    Args:
        s_pred: Predicted state; its fourth ``split_state`` component is used.
        s_target: Target state; its fourth ``split_state`` component sets the
            normalising scale and is detached from the graph.
        mask_btj: Optional mask forwarded to ``masked_mean``.
        fps: Unused; kept so existing callers that pass it keep working.
        eps: Lower bound on the normalising target scale.

    Returns:
        Ratio of the masked-mean squared predicted jerk to the masked-mean
        squared target jerk.
    """
    _, _, _, j_pred = split_state(s_pred)
    _, _, _, j_t = split_state(s_target)

    target_scale = masked_mean((j_t ** 2).sum(dim=-1), mask_btj)
    target_scale = target_scale.detach().clamp_min(eps)

    pred_mag = masked_mean((j_pred ** 2).sum(dim=-1), mask_btj)
    return pred_mag / target_scale


def _masked_sq_err(
    pred: torch.Tensor,
    target: torch.Tensor,
    mask_btj: Optional[torch.Tensor],
) -> torch.Tensor:
    """Masked mean of the squared error summed over the last axis."""
    return masked_mean(((pred - target) ** 2).sum(dim=-1), mask_btj)


def state_reconstruction_loss(
    s_pred: torch.Tensor,
    s_target: torch.Tensor,
    mask_btj: Optional[torch.Tensor] = None,
    w_p: float = 1.0,
    w_v: float = 1.0,
    w_a: float = 0.5,
    w_j: float = 0.25,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    """Weighted squared-error reconstruction loss over the four state parts.

    Args:
        s_pred: Predicted state, split with ``split_state``.
        s_target: Target state, split with ``split_state``.
        mask_btj: Optional mask forwarded to ``masked_mean``.
        w_p: Weight of the first (position) component.
        w_v: Weight of the second (velocity) component.
        w_a: Weight of the third (acceleration) component.
        w_j: Weight of the fourth (jerk) component.

    Returns:
        ``(total, parts)`` where ``total`` is the weighted sum and ``parts``
        maps ``"rec_p"``, ``"rec_v"``, ``"rec_a"``, ``"rec_j"`` to the
        unweighted per-component losses.
    """
    p_pred, v_pred, a_pred, j_pred = split_state(s_pred)
    p_tgt, v_tgt, a_tgt, j_tgt = split_state(s_target)

    loss_p = _masked_sq_err(p_pred, p_tgt, mask_btj)
    loss_v = _masked_sq_err(v_pred, v_tgt, mask_btj)
    loss_a = _masked_sq_err(a_pred, a_tgt, mask_btj)
    loss_j = _masked_sq_err(j_pred, j_tgt, mask_btj)

    total = w_p * loss_p + w_v * loss_v + w_a * loss_a + w_j * loss_j
    return total, {
        "rec_p": loss_p,
        "rec_v": loss_v,
        "rec_a": loss_a,
        "rec_j": loss_j,
    }


def _bone_lengths(positions: torch.Tensor, edges, eps: float) -> torch.Tensor:
    """Stack the length of every edge along a new last axis -> (B,T,E)."""
    return torch.stack(
        [
            # eps keeps sqrt differentiable when two joints coincide.
            torch.sqrt(
                ((positions[:, :, i] - positions[:, :, j]) ** 2).sum(dim=-1) + eps
            )
            for i, j in edges
        ],
        dim=-1,
    )


def bone_length_loss(
    s_pred: torch.Tensor,
    s_target: torch.Tensor,
    mask_btj: Optional[torch.Tensor] = None,
    edges=H36M17_EDGES,
    eps: float = 1e-8,
) -> torch.Tensor:
    """Mean squared difference between predicted and target bone lengths.

    Args:
        s_pred: Predicted state; its first ``split_state`` component is used
            as joint positions, indexed by joint along axis 2.
        s_target: Target state, used the same way.
        mask_btj: Optional per-joint mask indexed by joint along axis 2. An
            edge is weighted by the product of its two endpoint masks.
        edges: Iterable of ``(i, j)`` joint-index pairs.
        eps: Constant added under the square root for numerical stability.

    Returns:
        ``masked_mean`` of the squared bone-length error.
    """
    # Materialise once: edges is walked more than once below, which would
    # silently yield nothing on the second pass for a one-shot iterable.
    edge_list = list(edges)

    p_pred, _, _, _ = split_state(s_pred)
    p_tgt, _, _, _ = split_state(s_target)

    pred_lengths = _bone_lengths(p_pred, edge_list, eps)  # (B,T,E)
    tgt_lengths = _bone_lengths(p_tgt, edge_list, eps)  # (B,T,E)
    bone_err = (pred_lengths - tgt_lengths) ** 2

    if mask_btj is None:
        edge_mask = None
    else:
        # Edge mask is valid only when both endpoint joints are valid.
        edge_mask = torch.stack(
            [mask_btj[:, :, i] * mask_btj[:, :, j] for i, j in edge_list],
            dim=-1,
        )

    return masked_mean(bone_err, edge_mask)


class PINNPretrainLoss(nn.Module):
    """
    Step 1 only: physics-aware encoder pretraining loss

    Combines state reconstruction, kinematic consistency (pv / va / aj),
    jerk regularisation and bone-length terms into one weighted total.

    Args:
        fps: Frame rate passed to the kinematic and jerk terms.
        w_state_rec: Weight of the state reconstruction loss.
        w_pv: Weight of the position/velocity consistency term.
        w_va: Weight of the velocity/acceleration consistency term.
        w_aj: Weight of the acceleration/jerk consistency term.
        w_jerk: Weight of the jerk regulariser; the term is only computed
            when this is greater than zero.
        w_bone: Weight of the bone-length loss.
        w_rec_p: Position weight inside the reconstruction loss.
        w_rec_v: Velocity weight inside the reconstruction loss.
        w_rec_a: Acceleration weight inside the reconstruction loss.
        w_rec_j: Jerk weight inside the reconstruction loss.
    """

    def __init__(
        self,
        fps: float = 30.0,
        w_state_rec: float = 1.0,
        w_pv: float = 0.0,
        w_va: float = 0.0,
        w_aj: float = 0.0,
        w_jerk: float = 0.001,
        w_bone: float = 0.1,
        w_rec_p: float = 2.0,
        w_rec_v: float = 0.5,
        w_rec_a: float = 0.25,
        w_rec_j: float = 0.1,
    ):
        super().__init__()
        self.fps = fps
        self.w_state_rec = w_state_rec
        self.w_pv = w_pv
        self.w_va = w_va
        self.w_aj = w_aj
        self.w_jerk = w_jerk
        self.w_bone = w_bone
        self.w_rec_p = w_rec_p
        self.w_rec_v = w_rec_v
        self.w_rec_a = w_rec_a
        self.w_rec_j = w_rec_j

    def forward(
        self,
        s_pred: torch.Tensor,
        s_target: torch.Tensor,
        mask_btj: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """Compute every loss term and their weighted total.

        Args:
            s_pred: Predicted state.
            s_target: Target state.
            mask_btj: Optional mask forwarded to every term.

        Returns:
            Dict holding each individual term, ``"loss_total"``, and the
            per-component reconstruction losses under ``"loss_rec_*"``.
        """
        # Always computed so the terms can be monitored even at weight 0.
        kin = kinematic_consistency_losses(
            s_pred=s_pred,
            s_target=s_target,
            fps=self.fps,
            mask_btj=mask_btj,
            reduction="mean",
        )
        loss_pv = kin["pv"]
        loss_va = kin["va"]
        loss_aj = kin["aj"]

        loss_rec, rec_dict = state_reconstruction_loss(
            s_pred=s_pred,
            s_target=s_target,
            mask_btj=mask_btj,
            w_p=self.w_rec_p,
            w_v=self.w_rec_v,
            w_a=self.w_rec_a,
            w_j=self.w_rec_j,
        )

        loss_bone = bone_length_loss(
            s_pred=s_pred,
            s_target=s_target,
            mask_btj=mask_btj,
        )

        if self.w_jerk > 0.0:
            loss_jerk = jerk_reg_loss(
                s_pred=s_pred,
                s_target=s_target,
                mask_btj=mask_btj,
                fps=self.fps,
            )
        else:
            loss_jerk = s_pred.new_zeros(())

        weighted_terms = (
            (self.w_state_rec, loss_rec),
            (self.w_pv, loss_pv),
            (self.w_va, loss_va),
            (self.w_aj, loss_aj),
            (self.w_jerk, loss_jerk),
            (self.w_bone, loss_bone),
        )
        # Skip disabled terms instead of multiplying them by zero: 0 * inf and
        # 0 * nan are nan, so a non-finite but disabled term (the kinematic
        # ratios divide by a magnitude clamped to eps) would otherwise poison
        # the total and its gradients.
        active = [weight * term for weight, term in weighted_terms if weight != 0.0]
        if active:
            total = sum(active[1:], active[0])
        else:
            # Every weight is zero: keep the result attached to the graph so
            # a later backward() still works.
            total = loss_rec * 0.0

        loss_dict = {
            "loss_state_rec": loss_rec,
            "loss_pv": loss_pv,
            "loss_va": loss_va,
            "loss_aj": loss_aj,
            "loss_bone": loss_bone,
            "loss_jerk_reg": loss_jerk,
            "loss_total": total,
        }
        loss_dict.update({f"loss_{k}": v for k, v in rec_dict.items()})
        return loss_dict