| import torch
|
| import torch.nn as nn
|
| from torch.utils.checkpoint import checkpoint
|
|
|
| from Utils import return_edges, build_adj |
|
|
# Default skeleton edge list, produced by the project helper
# `Utils.return_edges()` and used as the default `edges` of EncoderBody
# (whose default joint count is J=17).
# NOTE(review): presumably the 17-joint Human3.6M skeleton, going by the
# name — confirm against Utils.return_edges.
H36M17_EDGES = return_edges()
|
|
|
|
|
class StateEmbedding(nn.Module):
    """Embed per-joint kinematic state into a ``d_model``-wide feature.

    The last input axis is read as four consecutive 3-vectors
    ``[p, v, a, j]`` (presumably position, velocity, acceleration and jerk).
    Each 3-vector gets its own linear map to ``d_state`` features, multiplied
    by a learnable scalar gain (initialised to 1). The four results are
    concatenated, projected to ``d_model``, layer-normalised and passed
    through dropout.

    Input:
        x: (B, T, J, 12)

    Output:
        h: (B, T, J, d_model)
    """

    def __init__(self, d_model: int, d_state: int = 64, dropout: float = 0.0):
        super().__init__()
        # One projection per 3-vector group. Attribute names and creation
        # order define the state_dict keys and seeded initialisation.
        self.p = nn.Linear(3, d_state)
        self.v = nn.Linear(3, d_state)
        self.a = nn.Linear(3, d_state)
        self.j = nn.Linear(3, d_state)

        fused_width = 4 * d_state
        self.proj = nn.Linear(fused_width, d_model)
        self.ln = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

        # Learnable per-group gains, one for each of p, v, a, j.
        self.scale = nn.Parameter(torch.ones(4))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``x`` of shape (B, T, J, 12) to features of shape (B, T, J, d_model).

        Raises:
            ValueError: if ``x`` is not 4-D or its last axis is not 12.
        """
        if not (x.ndim == 4 and x.shape[-1] == 12):
            raise ValueError(f"Expected input shape (B,T,J,12), got {tuple(x.shape)}")

        groups = torch.split(x, 3, dim=-1)
        encoders = (self.p, self.v, self.a, self.j)
        embedded = [
            encoder(group) * self.scale[idx]
            for idx, (encoder, group) in enumerate(zip(encoders, groups))
        ]

        fused = self.proj(torch.cat(embedded, dim=-1))
        return self.drop(self.ln(fused))
|
|
|
|
|
class GraphMix(nn.Module):
    """Residual spatial mixing over joints using a fixed, row-normalised adjacency.

    Each joint aggregates the features of the joints in its adjacency row,
    with each row normalised to sum to about 1. The aggregate goes through a
    linear layer, is added back to the input as a residual, and the sum is
    layer-normalised. Whether a joint also attends to itself (self-loops)
    depends on ``build_adj`` — NOTE(review): confirm.

    Args:
        J: number of joints, forwarded to ``build_adj``.
        d_model: feature width D.
        edges: edge list forwarded to ``build_adj``.
    """

    def __init__(self, J: int, d_model: int, edges):
        super().__init__()
        adjacency = build_adj(J, edges)
        # Row-normalise; the epsilon keeps all-zero rows (isolated joints)
        # from dividing by zero.
        row_totals = adjacency.sum(dim=-1, keepdim=True)
        adjacency = adjacency / (row_totals + 1e-8)

        # Buffer, not a parameter: it is not trained, but it follows .to()
        # and is stored in the state_dict under the key "A".
        self.register_buffer("A", adjacency)
        self.fc = nn.Linear(d_model, d_model)
        self.ln = nn.LayerNorm(d_model)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        """Mix joint features; ``h`` and the result are (B, T, J, D)."""
        # neighbours[b, t, j] = sum_k A[j, k] * h[b, t, k]
        neighbours = torch.einsum("jk,btkc->btjc", self.A, h)
        return self.ln(h + self.fc(neighbours))
|
|
|
|
|
|
|
class TemporalBlock(nn.Module):
    """Temporal mixing block applied independently to every joint.

    A depthwise convolution over time (one filter per channel) followed by a
    pointwise 1x1 convolution forms a residual branch. That branch is added to
    the input and layer-normalised. A two-layer GELU feed-forward network,
    applied over the feature axis, then forms a second residual branch, again
    followed by LayerNorm.

    Args:
        d_model: Feature width D; also the number of convolution channels.
        kernel: Temporal kernel size (>= 1). Odd and even sizes are both
            supported: the depthwise convolution uses ``padding="same"``, so
            the output always keeps T time steps.
        mlp_ratio: Hidden-width multiplier of the feed-forward network.
        dropout: Dropout probability inside the feed-forward network.

    Raises:
        ValueError: If ``kernel`` is smaller than 1.
    """

    def __init__(
        self,
        d_model: int,
        kernel: int = 5,
        mlp_ratio: int = 2,
        dropout: float = 0.0,
    ):
        super().__init__()
        if kernel < 1:
            raise ValueError(f"Expected kernel >= 1, got {kernel}")

        # padding="same" keeps the output length equal to T for any kernel
        # size. A symmetric padding of kernel // 2 yields T + 1 steps for even
        # kernels, which makes the .view(B, J, D, T) in forward() fail. For
        # odd kernels "same" pads kernel // 2 on both sides, so results match
        # the symmetric padding exactly.
        self.dw = nn.Conv1d(
            in_channels=d_model,
            out_channels=d_model,
            kernel_size=kernel,
            padding="same",
            groups=d_model,
        )
        self.pw = nn.Conv1d(
            in_channels=d_model,
            out_channels=d_model,
            kernel_size=1,
        )
        self.ln1 = nn.LayerNorm(d_model)

        hidden = d_model * mlp_ratio
        self.ffn = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, d_model),
            nn.Dropout(dropout),
        )
        self.ln2 = nn.LayerNorm(d_model)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        """Mix features along time; ``h`` and the result are (B, T, J, D)."""
        B, T, J, D = h.shape

        # Fold joints into the batch axis: (B, T, J, D) -> (B*J, D, T), so the
        # convolutions run along time for each joint independently.
        x = h.permute(0, 2, 3, 1).contiguous().view(B * J, D, T)
        x = self.pw(self.dw(x))
        # Unfold back to (B, T, J, D).
        x = x.view(B, J, D, T).permute(0, 3, 1, 2).contiguous()

        h = self.ln1(h + x)
        h = self.ln2(h + self.ffn(h))
        return h
|
|
|
|
|
class EncoderBody(nn.Module):
    """Spatio-temporal encoder backbone.

    Backbone body for Step 1 pretraining; the same body can be reused in
    Step 2 JEPA.

    Pipeline:
        1. ``StateEmbedding`` maps each joint's 12-dim state to ``d_model``.
        2. ``depth`` rounds of spatial mixing (``GraphMix`` over the skeleton
           graph) followed by temporal mixing (``TemporalBlock`` along T).
        3. A final LayerNorm.

    When ``use_checkpoint`` is true and the module is in training mode, each
    spatial/temporal block runs under non-reentrant activation checkpointing.
    Activations are then recomputed in backward instead of being stored.

    Input:
        x: (B, T, J, 12)

    Output:
        h: (B, T, J, d_model)
    """

    def __init__(
        self,
        J: int = 17,
        d_model: int = 256,
        depth: int = 6,
        edges=H36M17_EDGES,
        d_state: int = 64,
        temporal_kernel: int = 5,
        mlp_ratio: int = 2,
        dropout: float = 0.0,
        use_checkpoint: bool = True,
    ):
        super().__init__()

        self.J = J
        self.d_model = d_model
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        self.embed = StateEmbedding(d_model=d_model, d_state=d_state, dropout=dropout)

        # Every spatial block is constructed before any temporal block. Keep
        # this order so seeded initialisation and state_dict layout are stable.
        spatial_blocks = [
            GraphMix(J=J, d_model=d_model, edges=edges)
            for _ in range(depth)
        ]
        self.spatial = nn.ModuleList(spatial_blocks)

        temporal_blocks = [
            TemporalBlock(
                d_model=d_model,
                kernel=temporal_kernel,
                mlp_ratio=mlp_ratio,
                dropout=dropout,
            )
            for _ in range(depth)
        ]
        self.temporal = nn.ModuleList(temporal_blocks)

        self.final_ln = nn.LayerNorm(d_model)

    def _run_block(self, module: nn.Module, x: torch.Tensor) -> torch.Tensor:
        """Apply ``module`` to ``x``, checkpointed only when enabled and training."""
        wants_checkpoint = self.use_checkpoint and self.training
        if not wants_checkpoint:
            return module(x)
        return checkpoint(module, x, use_reentrant=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` of shape (B, T, J, 12) into features of shape (B, T, J, D).

        Raises:
            ValueError: if ``x`` is not 4-D, has the wrong joint count, or
                its last axis is not 12.
        """
        if x.ndim != 4:
            raise ValueError(f"Expected x.ndim == 4, got {x.ndim}")

        _, _, J, C = x.shape
        if J != self.J:
            raise ValueError(f"Expected J={self.J}, got {J}")
        if C != 12:
            raise ValueError(f"Expected C=12, got {C}")

        h = self.embed(x)
        # Each layer is a spatial step followed by a temporal step.
        for layer in zip(self.spatial, self.temporal):
            for block in layer:
                h = self._run_block(block, h)

        return self.final_ln(h)
|
|