# File size: 1,631 Bytes
# 5cb4913
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
"""
Minimal load + forward example for the PINN-JEPA Step 1 model.

Run from the repository root (so the flat imports resolve):
    python inference_example.py
"""
import json
import torch

from PINN_EncoderBody import EncoderBody
from PINN_PretrainModel import PINNPretrainModel


def build_model(config_path="config.json", weights_path=None, device="cpu"):
    """Build the pretraining model from a JSON config and optionally load weights.

    Args:
        config_path: Path to a JSON file that provides an ``"encoder"``
            mapping (forwarded as keyword arguments to ``EncoderBody``), an
            ``"fps"`` value, and optionally ``"hidden_ratio"`` (defaults
            to 2 when absent).
        weights_path: Optional checkpoint path readable by ``torch.load``.
            When ``None``, no weights are loaded and the model keeps the
            parameters it was constructed with.
        device: Device the returned model is moved to.

    Returns:
        The ``PINNPretrainModel`` instance, moved to ``device`` and switched
        to eval mode.

    Raises:
        FileNotFoundError: If ``config_path`` does not exist.
        KeyError: If the config lacks the ``"encoder"`` or ``"fps"`` entry.
    """
    # Use a context manager so the config handle is closed deterministically
    # instead of being left for the garbage collector; JSON text is UTF-8,
    # so do not depend on the platform's default encoding.
    with open(config_path, encoding="utf-8") as cfg_file:
        cfg = json.load(cfg_file)

    encoder = EncoderBody(**cfg["encoder"])
    model = PINNPretrainModel(encoder=encoder, fps=cfg["fps"],
                              hidden_ratio=cfg.get("hidden_ratio", 2))

    if weights_path is not None:
        # NOTE(security): torch.load may unpickle arbitrary objects depending
        # on the installed torch version's `weights_only` default -- only
        # load checkpoints from a trusted source.
        # Tensors are first materialised on CPU; the final `.to(device)`
        # below performs the actual device placement.
        state = torch.load(weights_path, map_location="cpu")
        # Accept both a bare state dict and a checkpoint that nests it
        # under the "model" key.
        if isinstance(state, dict) and "model" in state:
            state = state["model"]
        # strict=False tolerates key mismatches; report them rather than
        # failing so the caller can judge whether the checkpoint fits.
        missing, unexpected = model.load_state_dict(state, strict=False)
        if missing:
            print(f"[warn] missing keys: {missing}")
        if unexpected:
            print(f"[warn] unexpected keys: {unexpected}")

    return model.to(device).eval()


@torch.no_grad()
def main():
    """Run one forward pass on random input and print the output shapes.

    The model is built without a checkpoint, so the printed shapes only
    illustrate the input/output contract, not meaningful predictions.
    """
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    # No weights file is supplied: the model keeps its initial parameters.
    model = build_model(weights_path=None, device=device)

    batch_size, num_frames, num_joints = 2, 64, 17
    # Last axis holds 12 channels per joint ([p, v, a, j] per the original
    # annotation).
    sample_shape = (batch_size, num_frames, num_joints, 12)
    x = torch.randn(*sample_shape, device=device)

    out = model(x)

    # (label, output key) pairs; shapes noted in the original annotations:
    # token_feat -> (B, T, J, D), s_hat -> (B, T, J, 12), p_hat -> (B, T, J, 3)
    report = (
        ("token_feat:", "token_feat"),
        ("s_hat     :", "s_hat"),
        ("p_hat     :", "p_hat"),
    )
    for label, key in report:
        print(label, tuple(out[key].shape))

# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()