pinn / PINN_EncoderBody.py
SeongvinJu's picture
Upload folder using huggingface_hub
5cb4913 verified
Raw
History Blame Contribute Delete
5.73 kB
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from Utils import return_edges, build_adj
# Skeleton edge list, computed once at import time and used below as the
# default `edges` argument of EncoderBody.
# NOTE(review): judging by the name this is presumably the 17-joint Human3.6M
# bone connectivity — confirm against Utils.return_edges.
H36M17_EDGES = return_edges()
# State Embedding
class StateEmbedding(nn.Module):
    """Embed per-joint state vectors into ``d_model`` features.

    The last input axis is cut into four consecutive 3-vectors, labelled
    p, v, a, j (presumably position, velocity, acceleration and jerk —
    confirm with the data pipeline). Each 3-vector has its own linear
    embedding and a learnable scalar gain; the four embeddings are
    concatenated, projected to ``d_model``, layer-normalised and passed
    through dropout.

    Input:
        x: (B, T, J, 12) = [p(3), v(3), a(3), j(3)]
    Output:
        h: (B, T, J, D)
    Raises:
        ValueError: if ``x`` is not 4-D or its last axis is not of size 12.
    """

    def __init__(self, d_model: int, d_state: int = 64, dropout: float = 0.0):
        super().__init__()

        # One 3 -> d_state embedding per physical component.
        self.p = nn.Linear(3, d_state)
        self.v = nn.Linear(3, d_state)
        self.a = nn.Linear(3, d_state)
        self.j = nn.Linear(3, d_state)

        # Fuse the four concatenated embeddings down to the model width.
        self.proj = nn.Linear(4 * d_state, d_model)
        self.ln = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

        # Learnable gain for each physical component, initialised to 1.
        self.scale = nn.Parameter(torch.ones(4))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        wrong_rank = x.ndim != 4
        if wrong_rank or x.shape[-1] != 12:
            raise ValueError(f"Expected input shape (B,T,J,12), got {tuple(x.shape)}")

        # Four chunks of width 3 along the feature axis, in p, v, a, j order.
        chunks = x.split(3, dim=-1)
        embedders = (self.p, self.v, self.a, self.j)

        scaled = [
            embed(chunk) * self.scale[idx]
            for idx, (embed, chunk) in enumerate(zip(embedders, chunks))
        ]

        fused = torch.cat(scaled, dim=-1)
        return self.drop(self.ln(self.proj(fused)))  # (B, T, J, D)
# Simple Graph Mixing (stable spatial mixing)
class GraphMix(nn.Module):
    """Residual mixing of joint features through a fixed, row-normalised graph.

    The matrix returned by ``build_adj(J, edges)`` is divided row-wise by its
    row sum (plus 1e-8 to avoid division by zero) and stored as the
    non-trainable buffer ``A``. In the forward pass every joint receives the
    ``A``-weighted sum of all joints' features; that message goes through a
    linear layer, is added back to the input and layer-normalised.

    Args:
        J: number of joints, forwarded to ``build_adj``.
        d_model: feature width D.
        edges: edge specification forwarded to ``build_adj``.
    """

    def __init__(self, J: int, d_model: int, edges):
        super().__init__()
        adjacency = build_adj(J, edges)  # (J, J)
        row_totals = adjacency.sum(dim=-1, keepdim=True)
        self.register_buffer("A", adjacency / (row_totals + 1e-8))

        self.fc = nn.Linear(d_model, d_model)
        self.ln = nn.LayerNorm(d_model)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        """
        h: (B, T, J, D)
        return: (B, T, J, D)
        """
        # For each target joint i, aggregate source joints j weighted by A[i, j].
        neighbours = torch.einsum("ij,btjd->btid", self.A, h)
        update = self.fc(neighbours)
        mixed = h + update
        return self.ln(mixed)
# Temporal Block (depthwise conv + pointwise conv + FFN)
class TemporalBlock(nn.Module):
    """Per-joint temporal mixing: depthwise conv + pointwise conv + FFN.

    Each joint's feature sequence is filtered along the time axis by a
    depthwise 1-D convolution followed by a 1x1 (pointwise) convolution. The
    result is added to the input and layer-normalised; a two-layer GELU MLP
    with its own residual connection and layer norm follows.

    Args:
        d_model: feature width D (also the number of conv channels).
        kernel: temporal kernel size of the depthwise convolution. Both odd
            and even sizes are supported; the sequence length is always
            preserved.
        mlp_ratio: hidden width of the FFN as a multiple of ``d_model``.
        dropout: dropout probability used inside the FFN.
    """

    def __init__(
        self,
        d_model: int,
        kernel: int = 5,
        mlp_ratio: int = 2,
        dropout: float = 0.0,
    ):
        super().__init__()
        # "same" padding keeps the output length equal to T for any kernel
        # size. For odd kernels it equals the former `kernel // 2` padding on
        # both sides; for even kernels `kernel // 2` produced T + 1 steps and
        # made the reshape in forward() fail, whereas "same" pads
        # asymmetrically so the length is still T.
        self.dw = nn.Conv1d(
            in_channels=d_model,
            out_channels=d_model,
            kernel_size=kernel,
            padding="same",
            groups=d_model,
        )
        self.pw = nn.Conv1d(
            in_channels=d_model,
            out_channels=d_model,
            kernel_size=1,
        )
        self.ln1 = nn.LayerNorm(d_model)

        hidden = d_model * mlp_ratio
        self.ffn = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, d_model),
            nn.Dropout(dropout),
        )
        self.ln2 = nn.LayerNorm(d_model)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        """
        h: (B, T, J, D)
        return: (B, T, J, D)
        """
        B, T, J, D = h.shape

        # Temporal conv for each joint independently: fold joints into the
        # batch axis and move time last, as Conv1d expects (N, C, L).
        x = h.permute(0, 2, 3, 1).reshape(B * J, D, T)
        x = self.pw(self.dw(x))

        # Back to (B, T, J, D). reshape (not view) so a non-contiguous conv
        # output cannot raise.
        x = x.reshape(B, J, D, T).permute(0, 3, 1, 2).contiguous()

        h = self.ln1(h + x)
        h = self.ln2(h + self.ffn(h))
        return h
# Encoder Body
class EncoderBody(nn.Module):
    """
    Backbone body for Step 1 pretraining.
    The same body can be reused in Step 2 JEPA.

    The input is embedded by ``StateEmbedding`` and then passed through
    ``depth`` stages, each a ``GraphMix`` (mixing across joints) followed by
    a ``TemporalBlock`` (mixing across time). A final LayerNorm is applied
    to the output.

    Input:
        x: (B, T, J, 12)
    Output:
        h: (B, T, J, D)

    Args:
        J: expected number of joints; checked against the input in forward().
        d_model: feature width D.
        depth: number of (spatial, temporal) stages.
        edges: edge specification handed to every GraphMix.
        d_state: per-component embedding width inside StateEmbedding.
        temporal_kernel: kernel size of each TemporalBlock.
        mlp_ratio: FFN expansion ratio of each TemporalBlock.
        dropout: dropout probability for the embedding and temporal blocks.
        use_checkpoint: if True, blocks run under activation checkpointing
            while the module is in training mode.
    """

    def __init__(
        self,
        J: int = 17,
        d_model: int = 256,
        depth: int = 6,
        edges=H36M17_EDGES,
        d_state: int = 64,
        temporal_kernel: int = 5,
        mlp_ratio: int = 2,
        dropout: float = 0.0,
        use_checkpoint: bool = True,
    ):
        super().__init__()
        self.J, self.d_model, self.depth = J, d_model, depth
        self.use_checkpoint = use_checkpoint

        self.embed = StateEmbedding(d_model=d_model, d_state=d_state, dropout=dropout)

        # All spatial layers are created before any temporal layer so the
        # parameter-initialisation order stays the same.
        spatial_cfg = dict(J=J, d_model=d_model, edges=edges)
        self.spatial = nn.ModuleList([GraphMix(**spatial_cfg) for _ in range(depth)])

        temporal_cfg = dict(
            d_model=d_model,
            kernel=temporal_kernel,
            mlp_ratio=mlp_ratio,
            dropout=dropout,
        )
        self.temporal = nn.ModuleList([TemporalBlock(**temporal_cfg) for _ in range(depth)])

        self.final_ln = nn.LayerNorm(d_model)

    def _run_block(self, module: nn.Module, x: torch.Tensor) -> torch.Tensor:
        """Apply ``module`` to ``x``, checkpointed only when enabled and training."""
        checkpointing = self.use_checkpoint and self.training
        if not checkpointing:
            return module(x)
        return checkpoint(module, x, use_reentrant=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (B, T, J, 12)
        return:
        h: (B, T, J, D)

        Raises ValueError when x is not 4-D, when its joint axis differs
        from ``self.J``, or when its last axis is not of size 12.
        """
        if x.ndim != 4:
            raise ValueError(f"Expected x.ndim == 4, got {x.ndim}")

        _, _, J, C = x.shape
        if J != self.J:
            raise ValueError(f"Expected J={self.J}, got {J}")
        if C != 12:
            raise ValueError(f"Expected C=12, got {C}")

        h = self.embed(x)

        # Each stage is (spatial, temporal); run the two halves in that order.
        for stage in zip(self.spatial, self.temporal):
            for block in stage:
                h = self._run_block(block, h)

        return self.final_ln(h)