import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

from Utils import return_edges, build_adj

# Default skeleton edge list, obtained once at import time from the project helper.
H36M17_EDGES = return_edges()


# State Embedding
class StateEmbedding(nn.Module):
    """Embed per-joint physical state vectors into the model dimension.

    Input:
        x: (B, T, J, 12) = [p(3), v(3), a(3), j(3)]
    Output:
        h: (B, T, J, D)

    Each 3-vector component is projected to ``d_state`` features by its own
    linear layer, multiplied by a learnable scalar, concatenated, and then
    projected to ``d_model`` followed by LayerNorm and dropout.
    """

    def __init__(self, d_model: int, d_state: int = 64, dropout: float = 0.0):
        super().__init__()
        # One linear projection per physical component (3 -> d_state).
        self.p = nn.Linear(3, d_state)
        self.v = nn.Linear(3, d_state)
        self.a = nn.Linear(3, d_state)
        self.j = nn.Linear(3, d_state)

        # Fuse the four component embeddings into the model dimension.
        self.proj = nn.Linear(4 * d_state, d_model)
        self.ln = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

        # learnable scaling for each physical component
        self.scale = nn.Parameter(torch.ones(4))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the (B, T, J, D) embedding of ``x``.

        Raises:
            ValueError: if ``x`` is not 4-dimensional or its last dimension
                is not 12.
        """
        if x.ndim != 4 or x.shape[-1] != 12:
            raise ValueError(f"Expected input shape (B,T,J,12), got {tuple(x.shape)}")

        # Split the 12 channels into four consecutive chunks of 3.
        p, v, a, j = torch.split(x, 3, dim=-1)

        ep = self.p(p) * self.scale[0]
        ev = self.v(v) * self.scale[1]
        ea = self.a(a) * self.scale[2]
        ej = self.j(j) * self.scale[3]

        h = torch.cat([ep, ev, ea, ej], dim=-1)
        h = self.proj(h)
        h = self.ln(h)
        h = self.drop(h)
        return h  # (B, T, J, D)


# Simple Graph Mixing (stable spatial mixing)
class GraphMix(nn.Module):
    """Mix features across joints with a fixed, row-normalised adjacency.

    The adjacency is produced by ``build_adj(J, edges)`` (expected to be a
    (J, J) tensor), divided row-wise by its row sum, and stored as a
    non-trainable buffer ``A``.
    """

    def __init__(self, J: int, d_model: int, edges):
        super().__init__()
        A = build_adj(J, edges)  # (J, J)
        # Row-normalise; the epsilon guards against division by zero for
        # rows that sum to zero.
        A = A / (A.sum(dim=-1, keepdim=True) + 1e-8)
        self.register_buffer("A", A)

        self.fc = nn.Linear(d_model, d_model)
        self.ln = nn.LayerNorm(d_model)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        """
        h: (B, T, J, D)

        Returns a tensor of the same shape: LayerNorm(h + fc(A @ h)), where
        the adjacency product is taken over the joint axis.
        """
        # For every output joint i, sum neighbour features j weighted by A[i, j].
        msg = torch.einsum("ij,btjd->btid", self.A, h)
        out = self.fc(msg)
        return self.ln(h + out)


# Temporal Block (depthwise conv + pointwise conv + FFN)
class TemporalBlock(nn.Module):
    """Temporal mixing applied independently to every joint.

    A depthwise Conv1d over time followed by a pointwise (kernel size 1)
    Conv1d forms the first residual branch; a two-layer GELU feed-forward
    network forms the second. Each branch is followed by LayerNorm.

    Args:
        d_model: feature dimension D.
        kernel: temporal kernel size of the depthwise conv. Must be a
            positive odd integer so that ``padding=kernel // 2`` keeps the
            sequence length unchanged.
        mlp_ratio: hidden width of the FFN as a multiple of ``d_model``.
        dropout: dropout probability used inside the FFN.

    Raises:
        ValueError: if ``kernel`` is not a positive odd integer.
    """

    def __init__(
        self,
        d_model: int,
        kernel: int = 5,
        mlp_ratio: int = 2,
        dropout: float = 0.0,
    ):
        super().__init__()

        # With padding = kernel // 2 and stride 1 the conv output length is
        # T + 2 * (kernel // 2) - kernel + 1, which equals T only for odd
        # kernels. An even kernel would yield T + 1 and break the reshape in
        # forward(), so reject it here with a clear message.
        if kernel < 1 or kernel % 2 == 0:
            raise ValueError(
                f"TemporalBlock requires a positive odd kernel size, got {kernel}"
            )

        self.dw = nn.Conv1d(
            in_channels=d_model,
            out_channels=d_model,
            kernel_size=kernel,
            padding=kernel // 2,
            groups=d_model,  # depthwise: one filter per channel
        )
        self.pw = nn.Conv1d(
            in_channels=d_model,
            out_channels=d_model,
            kernel_size=1,
        )
        self.ln1 = nn.LayerNorm(d_model)

        hidden = d_model * mlp_ratio
        self.ffn = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, d_model),
            nn.Dropout(dropout),
        )
        self.ln2 = nn.LayerNorm(d_model)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        """
        h: (B, T, J, D)

        Returns a tensor of the same shape.
        """
        B, T, J, D = h.shape

        # temporal conv for each joint independently:
        # (B, T, J, D) -> (B, J, D, T) -> (B * J, D, T) so Conv1d runs over T.
        x = h.permute(0, 2, 3, 1).contiguous().view(B * J, D, T)
        x = self.pw(self.dw(x))
        # Undo the folding: (B * J, D, T) -> (B, J, D, T) -> (B, T, J, D).
        x = x.view(B, J, D, T).permute(0, 3, 1, 2).contiguous()

        h = self.ln1(h + x)
        h2 = self.ffn(h)
        h = self.ln2(h + h2)
        return h


# Encoder Body
class EncoderBody(nn.Module):
    """
    Backbone body for Step 1 pretraining.
    The same body can be reused in Step 2 JEPA.

    Input:
        x: (B, T, J, 12)
    Output:
        h: (B, T, J, D)

    The body embeds the input with ``StateEmbedding`` and then applies
    ``depth`` pairs of (``GraphMix``, ``TemporalBlock``) followed by a final
    LayerNorm.

    Args:
        J: number of joints expected on axis 2 of the input.
        d_model: feature dimension D.
        depth: number of (spatial, temporal) block pairs.
        edges: edge list passed to ``GraphMix``.
        d_state: per-component embedding width in ``StateEmbedding``.
        temporal_kernel: kernel size for every ``TemporalBlock`` (must be a
            positive odd integer).
        mlp_ratio: FFN expansion ratio for every ``TemporalBlock``.
        dropout: dropout probability for the embedding and FFNs.
        use_checkpoint: if True, blocks are run through
            ``torch.utils.checkpoint.checkpoint`` while the module is in
            training mode.
    """

    def __init__(
        self,
        J: int = 17,
        d_model: int = 256,
        depth: int = 6,
        edges=H36M17_EDGES,
        d_state: int = 64,
        temporal_kernel: int = 5,
        mlp_ratio: int = 2,
        dropout: float = 0.0,
        use_checkpoint: bool = True,
    ):
        super().__init__()
        self.J = J
        self.d_model = d_model
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        self.embed = StateEmbedding(
            d_model=d_model,
            d_state=d_state,
            dropout=dropout,
        )

        self.spatial = nn.ModuleList([
            GraphMix(J=J, d_model=d_model, edges=edges)
            for _ in range(depth)
        ])

        self.temporal = nn.ModuleList([
            TemporalBlock(
                d_model=d_model,
                kernel=temporal_kernel,
                mlp_ratio=mlp_ratio,
                dropout=dropout,
            )
            for _ in range(depth)
        ])

        self.final_ln = nn.LayerNorm(d_model)

    def _run_block(self, module: nn.Module, x: torch.Tensor) -> torch.Tensor:
        """Run ``module`` on ``x``, via activation checkpointing when enabled.

        Checkpointing is used only when ``use_checkpoint`` is set and the
        module is in training mode; otherwise the module is called directly.
        """
        if self.use_checkpoint and self.training:
            return checkpoint(module, x, use_reentrant=False)
        return module(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (B, T, J, 12)
        return:
            h: (B, T, J, D)

        Raises:
            ValueError: if ``x`` is not 4-dimensional, its joint axis does
                not match ``self.J``, or its channel axis is not 12.
        """
        if x.ndim != 4:
            raise ValueError(f"Expected x.ndim == 4, got {x.ndim}")

        _, _, J, C = x.shape
        if J != self.J:
            raise ValueError(f"Expected J={self.J}, got {J}")
        if C != 12:
            raise ValueError(f"Expected C=12, got {C}")

        h = self.embed(x)

        # Alternate spatial (joint) mixing and temporal mixing.
        for s_block, t_block in zip(self.spatial, self.temporal):
            h = self._run_block(s_block, h)
            h = self._run_block(t_block, h)

        h = self.final_ln(h)
        return h