""" Minimal load + forward example for the PINN-JEPA Step 1 model. Run from the repository root (so the flat imports resolve): python inference_example.py """ import json import torch from PINN_EncoderBody import EncoderBody from PINN_PretrainModel import PINNPretrainModel def build_model(config_path="config.json", weights_path=None, device="cpu"): cfg = json.load(open(config_path)) encoder = EncoderBody(**cfg["encoder"]) model = PINNPretrainModel(encoder=encoder, fps=cfg["fps"], hidden_ratio=cfg.get("hidden_ratio", 2)) if weights_path is not None: state = torch.load(weights_path, map_location="cpu") if isinstance(state, dict) and "model" in state: state = state["model"] missing, unexpected = model.load_state_dict(state, strict=False) if missing: print(f"[warn] missing keys: {missing}") if unexpected: print(f"[warn] unexpected keys: {unexpected}") return model.to(device).eval() @torch.no_grad() def main(): device = "cuda" if torch.cuda.is_available() else "cpu" # weights_path=None -> random init, just to show the I/O contract. model = build_model(weights_path=None, device=device) B, T, J = 2, 64, 17 x = torch.randn(B, T, J, 12, device=device) # [p, v, a, j] per joint out = model(x) print("token_feat:", tuple(out["token_feat"].shape)) # (B, T, J, D) print("s_hat :", tuple(out["s_hat"].shape)) # (B, T, J, 12) print("p_hat :", tuple(out["p_hat"].shape)) # (B, T, J, 3) if __name__ == "__main__": main()