# pinn / PINN_Lossfunction.py
# Uploaded by SeongvinJu via huggingface_hub ("Upload folder using huggingface_hub"),
# commit 5cb4913 (verified), 6.35 kB.
import torch
import torch.nn as nn
from typing import Optional
from Utils import return_edges, split_state, central_diff, masked_mean
# Skeleton edge list used as the default ``edges`` of bone_length_loss, where it
# is consumed as (i, j) joint-index pairs. The name suggests the 17-joint
# Human3.6M skeleton -- TODO confirm against Utils.return_edges.
H36M17_EDGES = return_edges()
def kinematic_consistency_losses(
    s_pred: torch.Tensor,
    s_target: torch.Tensor,
    fps: float,
    mask_btj: Optional[torch.Tensor] = None,
    reduction: str = "mean",
    eps: float = 1e-6,
):
    """Relative residuals between predicted derivatives and finite differences.

    For the predicted state (p, v, a, j) returned by ``split_state``, the
    following are compared:

    * "pv": v vs. central_diff(p)
    * "va": a vs. central_diff(v)
    * "aj": j vs. central_diff(a)

    Each squared residual norm (over the last axis) is divided by the squared
    norm of the corresponding *target* quantity. The target is detached and its
    squared norm is clamped to ``eps``, so near-static targets cannot cause a
    division by zero.

    Args:
        s_pred: Predicted state tensor, split with ``split_state``.
        s_target: Target state tensor, split with ``split_state``.
        fps: Frame rate; the time step is ``1 / fps``. Must be positive.
        mask_btj: Optional validity mask forwarded to ``masked_mean``.
        reduction: ``"mean"`` returns masked means; ``"none"`` returns the
            unreduced per-element ratios.
        eps: Lower clamp for the target squared norms.

    Returns:
        Dict with keys ``"pv"``, ``"va"`` and ``"aj"``.

    Raises:
        ValueError: If ``reduction`` is unknown or ``fps`` is not positive.
    """
    # Validate cheap arguments before doing any tensor work.
    if reduction not in ("mean", "none"):
        raise ValueError(f"Unknown reduction: {reduction}")
    # A zero fps divides by zero; a negative fps would silently flip the sign
    # of every finite difference and make the residuals meaningless.
    if not fps > 0:
        raise ValueError(f"fps must be positive, got {fps}")

    dt = 1.0 / fps
    p, v, a, j = split_state(s_pred)
    _, v_t, a_t, j_t = split_state(s_target)

    def _relative_sq_residual(higher, lower, reference):
        # ||higher - d(lower)/dt||^2 normalised by the detached ||reference||^2.
        residual = higher - central_diff(lower, dt)
        scale = (reference.detach() ** 2).sum(dim=-1).clamp_min(eps)
        return (residual ** 2).sum(dim=-1) / scale

    losses = {
        "pv": _relative_sq_residual(v, p, v_t),
        "va": _relative_sq_residual(a, v, a_t),
        "aj": _relative_sq_residual(j, a, j_t),
    }
    if reduction == "mean":
        return {name: masked_mean(value, mask_btj) for name, value in losses.items()}
    return losses
def jerk_reg_loss(
    s_pred: torch.Tensor,
    s_target: torch.Tensor,
    mask_btj: Optional[torch.Tensor] = None,
    fps: float = 30.0,
    eps: float = 1e-6,
):
    """Relative jerk-magnitude regulariser.

    Computes the masked mean of the predicted squared jerk norm (over the last
    axis). It is divided by the masked mean of the target squared jerk norm,
    which is detached from the graph and clamped below by ``eps``.

    Note: ``fps`` is accepted for signature compatibility but is not used.
    """
    def _mean_sq_norm(jerk):
        # Squared L2 norm over the last axis, followed by the mask-aware mean.
        return masked_mean((jerk ** 2).sum(dim=-1), mask_btj)

    _p, _v, _a, jerk_pred = split_state(s_pred)
    _p, _v, _a, jerk_tgt = split_state(s_target)

    # The denominator only rescales the regulariser; no gradient flows into it.
    denom = _mean_sq_norm(jerk_tgt).detach().clamp_min(eps)
    return _mean_sq_norm(jerk_pred) / denom
def state_reconstruction_loss(
    s_pred: torch.Tensor,
    s_target: torch.Tensor,
    mask_btj: Optional[torch.Tensor] = None,
    w_p: float = 1.0,
    w_v: float = 1.0,
    w_a: float = 0.5,
    w_j: float = 0.25,
):
    """Weighted masked-MSE between predicted and target state components.

    Each of position, velocity, acceleration and jerk (from ``split_state``)
    contributes the masked mean of its squared error norm over the last axis.

    Returns:
        ``(total, parts)``. ``total`` is the weighted sum
        ``w_p*p + w_v*v + w_a*a + w_j*j``. ``parts`` maps ``"rec_p"``,
        ``"rec_v"``, ``"rec_a"`` and ``"rec_j"`` to the unweighted terms.
    """
    pp, pv, pa, pj = split_state(s_pred)
    tp, tv, ta, tj = split_state(s_target)

    table = (
        ("rec_p", w_p, pp, tp),
        ("rec_v", w_v, pv, tv),
        ("rec_a", w_a, pa, ta),
        ("rec_j", w_j, pj, tj),
    )

    parts = {}
    weighted = []
    for key, weight, pred, tgt in table:
        term = masked_mean(((pred - tgt) ** 2).sum(dim=-1), mask_btj)
        parts[key] = term
        weighted.append(weight * term)

    # sum() accumulates left to right, matching the p + v + a + j order.
    total = sum(weighted)
    return total, parts
def bone_length_loss(
    s_pred: torch.Tensor,
    s_target: torch.Tensor,
    mask_btj: Optional[torch.Tensor] = None,
    edges=H36M17_EDGES,
    eps: float = 1e-8,
):
    """Masked mean squared difference of predicted vs. target bone lengths.

    Positions come from ``split_state`` and are indexed as ``p[:, :, joint]``,
    so joints are assumed to be on dim 2 with coordinates on the last dim
    (presumably (B, T, J, D) -- TODO confirm).

    Args:
        s_pred: Predicted state tensor.
        s_target: Target state tensor.
        mask_btj: Optional per-joint validity mask indexed like ``p[:, :, joint]``.
            An edge counts only when both of its endpoints are valid.
        edges: Sequence of ``(i, j)`` joint-index pairs (list of tuples, array,
            or an (E, 2) tensor).
        eps: Added under the square root so the gradient stays finite for
            zero-length bones.

    Returns:
        Scalar from ``masked_mean`` over the (B, T, E) squared length errors,
        or a zero scalar when ``edges`` is empty.
    """
    p_pred, _, _, _ = split_state(s_pred)
    p_tgt, _, _, _ = split_state(s_target)

    edge_idx = torch.as_tensor(edges, dtype=torch.long, device=p_pred.device).reshape(-1, 2)
    if edge_idx.shape[0] == 0:
        # Previously torch.stack([]) raised here; no bones means no penalty.
        return p_pred.new_zeros(())
    src, dst = edge_idx[:, 0], edge_idx[:, 1]

    def _lengths(p: torch.Tensor) -> torch.Tensor:
        # A single gather over all edges yields (B, T, E, D) -> lengths (B, T, E).
        diff = p[:, :, src] - p[:, :, dst]
        return torch.sqrt((diff ** 2).sum(dim=-1) + eps)

    bone_err = (_lengths(p_pred) - _lengths(p_tgt)) ** 2  # (B, T, E)

    edge_mask = None
    if mask_btj is not None:
        # An edge is valid only when both endpoint joints are valid.
        m_src = src.to(mask_btj.device)
        m_dst = dst.to(mask_btj.device)
        edge_mask = mask_btj[:, :, m_src] * mask_btj[:, :, m_dst]
    return masked_mean(bone_err, edge_mask)
class PINNPretrainLoss(nn.Module):
    """Physics-aware encoder pretraining loss (step 1 only).

    Combines weighted state reconstruction, kinematic consistency (pv/va/aj),
    jerk regularisation and bone-length terms into a single scalar.

    Args:
        fps: Frame rate forwarded to the kinematic and jerk losses.
        w_state_rec: Weight of the state reconstruction term.
        w_pv: Weight of the position/velocity consistency term.
        w_va: Weight of the velocity/acceleration consistency term.
        w_aj: Weight of the acceleration/jerk consistency term.
        w_jerk: Weight of the jerk regulariser; when not positive, the term is
            not computed and is reported as zero.
        w_bone: Weight of the bone-length term.
        rec_w_p: Position weight forwarded to ``state_reconstruction_loss``.
        rec_w_v: Velocity weight forwarded to ``state_reconstruction_loss``.
        rec_w_a: Acceleration weight forwarded to ``state_reconstruction_loss``.
        rec_w_j: Jerk weight forwarded to ``state_reconstruction_loss``.
            The ``rec_w_*`` defaults reproduce the previously hard-coded values.
    """

    def __init__(
        self,
        fps: float = 30.0,
        w_state_rec: float = 1.0,
        w_pv: float = 0.0,
        w_va: float = 0.0,
        w_aj: float = 0.0,
        w_jerk: float = 0.001,
        w_bone: float = 0.1,
        rec_w_p: float = 2.0,
        rec_w_v: float = 0.5,
        rec_w_a: float = 0.25,
        rec_w_j: float = 0.1,
    ):
        super().__init__()
        self.fps = fps
        self.w_state_rec = w_state_rec
        self.w_pv = w_pv
        self.w_va = w_va
        self.w_aj = w_aj
        self.w_jerk = w_jerk
        self.w_bone = w_bone
        self.rec_w_p = rec_w_p
        self.rec_w_v = rec_w_v
        self.rec_w_a = rec_w_a
        self.rec_w_j = rec_w_j

    def forward(
        self,
        s_pred: torch.Tensor,
        s_target: torch.Tensor,
        mask_btj: Optional[torch.Tensor] = None,
    ):
        """Compute all loss terms and their weighted total.

        Returns:
            Dict with the following entries:

            * ``loss_state_rec``, ``loss_pv``, ``loss_va``, ``loss_aj``,
              ``loss_bone`` and ``loss_jerk_reg`` -- the unweighted terms.
            * ``loss_total`` -- their weighted sum.
            * ``loss_rec_p``, ``loss_rec_v``, ``loss_rec_a`` and
              ``loss_rec_j`` -- the reconstruction components.

            Terms with zero weight are still reported (for logging) but are
            left out of ``loss_total``.
        """
        loss_dict = {}

        kin = kinematic_consistency_losses(
            s_pred=s_pred,
            s_target=s_target,
            fps=self.fps,
            mask_btj=mask_btj,
            reduction="mean",
        )
        loss_pv = kin["pv"]
        loss_va = kin["va"]
        loss_aj = kin["aj"]

        loss_rec, rec_dict = state_reconstruction_loss(
            s_pred=s_pred,
            s_target=s_target,
            mask_btj=mask_btj,
            w_p=self.rec_w_p,
            w_v=self.rec_w_v,
            w_a=self.rec_w_a,
            w_j=self.rec_w_j,
        )

        loss_bone = bone_length_loss(
            s_pred=s_pred,
            s_target=s_target,
            mask_btj=mask_btj,
        )

        if self.w_jerk > 0.0:
            loss_jerk = jerk_reg_loss(
                s_pred=s_pred,
                s_target=s_target,
                mask_btj=mask_btj,
                fps=self.fps,
            )
        else:
            loss_jerk = s_pred.new_zeros(())

        weighted_terms = (
            (self.w_state_rec, loss_rec),
            (self.w_pv, loss_pv),
            (self.w_va, loss_va),
            (self.w_aj, loss_aj),
            (self.w_jerk, loss_jerk),
            (self.w_bone, loss_bone),
        )
        # Drop zero-weight terms: 0 * inf (e.g. an fp16 overflow of the
        # eps-normalised kinematic ratios) is NaN and would poison the total
        # and its gradient. For finite terms the result is unchanged.
        active = [w * term for w, term in weighted_terms if w != 0.0]
        if active:
            total = active[0]
            for term in active[1:]:
                total = total + term
        else:
            total = s_pred.new_zeros(())

        loss_dict["loss_state_rec"] = loss_rec
        loss_dict["loss_pv"] = loss_pv
        loss_dict["loss_va"] = loss_va
        loss_dict["loss_aj"] = loss_aj
        loss_dict["loss_bone"] = loss_bone
        loss_dict["loss_jerk_reg"] = loss_jerk
        loss_dict["loss_total"] = total
        loss_dict.update({f"loss_{k}": v for k, v in rec_dict.items()})
        return loss_dict