# pinn/inference_example.py
# Uploaded by SeongvinJu via huggingface_hub (commit 5cb4913, 1.63 kB).
"""
Minimal load + forward example for the PINN-JEPA Step 1 model.
Run from the repository root (so the flat imports resolve):
python inference_example.py
"""
import json
import torch
from PINN_EncoderBody import EncoderBody
from PINN_PretrainModel import PINNPretrainModel
def build_model(config_path="config.json", weights_path=None, device="cpu"):
    """Build a PINNPretrainModel from a JSON config and optionally load weights.

    Args:
        config_path: Path to a JSON config file. Its ``"encoder"`` entry is
            passed as keyword arguments to ``EncoderBody``; ``"fps"`` is
            required and ``"hidden_ratio"`` is optional (defaults to 2).
        weights_path: Optional checkpoint path. When given, the checkpoint is
            loaded onto the CPU with ``torch.load``; if it is a dict with a
            ``"model"`` key, that entry is used as the state dict. Keys are
            matched non-strictly and any missing/unexpected keys are printed.
        device: Device the returned model is moved to.

    Returns:
        The constructed model, moved to ``device`` and switched to eval mode.

    Raises:
        KeyError: If the config lacks the ``"encoder"`` or ``"fps"`` entry.
    """
    # Context manager so the config file handle is closed deterministically
    # (the bare open() previously leaked it until garbage collection).
    with open(config_path, encoding="utf-8") as f:
        cfg = json.load(f)

    encoder = EncoderBody(**cfg["encoder"])
    model = PINNPretrainModel(
        encoder=encoder,
        fps=cfg["fps"],
        hidden_ratio=cfg.get("hidden_ratio", 2),
    )

    if weights_path is not None:
        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from trusted sources, or pass
        # weights_only=True on PyTorch versions that support it.
        state = torch.load(weights_path, map_location="cpu")
        # Accept both a bare state dict and a {"model": state_dict, ...} wrapper.
        if isinstance(state, dict) and "model" in state:
            state = state["model"]
        # strict=False tolerates key mismatches; report them instead of failing.
        missing, unexpected = model.load_state_dict(state, strict=False)
        if missing:
            print(f"[warn] missing keys: {missing}")
        if unexpected:
            print(f"[warn] unexpected keys: {unexpected}")

    return model.to(device).eval()
@torch.no_grad()
def main():
    """Run one forward pass on random input and print the output shapes.

    The model is built with ``weights_path=None`` (random initialization), so
    the run only illustrates the model's input/output contract.
    """
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    model = build_model(weights_path=None, device=device)

    batch, frames, joints = 2, 64, 17
    # 12 channels per joint, laid out as [p, v, a, j].
    inputs = torch.randn(batch, frames, joints, 12, device=device)
    outputs = model(inputs)

    # Documented output shapes:
    #   token_feat -> (B, T, J, D), s_hat -> (B, T, J, 12), p_hat -> (B, T, J, 3)
    report = (
        ("token_feat:", "token_feat"),
        ("s_hat :", "s_hat"),
        ("p_hat :", "p_hat"),
    )
    for label, key in report:
        print(label, tuple(outputs[key].shape))
if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    main()