| """ |
| Minimal load + forward example for the PINN-JEPA Step 1 model. |
| |
| Run from the repository root (so the flat imports resolve): |
| python inference_example.py |
| """ |
| import json |
| import torch |
|
|
| from PINN_EncoderBody import EncoderBody |
| from PINN_PretrainModel import PINNPretrainModel |
|
|
|
|
def build_model(config_path="config.json", weights_path=None, device="cpu"):
    """Build the PINN pretraining model from a JSON config and optionally load weights.

    Args:
        config_path: Path to a JSON config file. It must contain an
            ``"encoder"`` mapping (expanded as keyword arguments to
            ``EncoderBody``) and an ``"fps"`` entry. ``"hidden_ratio"`` is
            optional and falls back to 2.
        weights_path: Optional checkpoint path passed to ``torch.load``.
            When ``None``, no checkpoint is loaded.
        device: Device the assembled model is moved to before it is returned.

    Returns:
        The ``PINNPretrainModel`` instance, moved to ``device`` and put in
        eval mode.

    Raises:
        FileNotFoundError: If ``config_path`` cannot be opened.
        KeyError: If the config has no ``"encoder"`` or ``"fps"`` entry.
    """
    # A context manager closes the config file deterministically; the bare
    # ``json.load(open(...))`` form leaves the handle to the garbage collector.
    with open(config_path, encoding="utf-8") as config_file:
        cfg = json.load(config_file)

    encoder = EncoderBody(**cfg["encoder"])
    model = PINNPretrainModel(
        encoder=encoder,
        fps=cfg["fps"],
        hidden_ratio=cfg.get("hidden_ratio", 2),
    )

    if weights_path is not None:
        # SECURITY: torch.load unpickles the file, which can execute arbitrary
        # code. Only load checkpoints from a trusted source (or pass
        # ``weights_only=True`` on PyTorch versions that support it).
        # Load onto CPU first; the whole model is moved to ``device`` below.
        state = torch.load(weights_path, map_location="cpu")

        # Some checkpoints wrap the state dict under a "model" key.
        if isinstance(state, dict) and "model" in state:
            state = state["model"]

        # strict=False tolerates key mismatches; report them instead of failing.
        missing, unexpected = model.load_state_dict(state, strict=False)
        if missing:
            print(f"[warn] missing keys: {missing}")
        if unexpected:
            print(f"[warn] unexpected keys: {unexpected}")

    return model.to(device).eval()
|
|
|
|
@torch.no_grad()
def main():
    """Run one forward pass on random input and print the output shapes.

    The model is built without a checkpoint (``weights_path=None``), placed
    on CUDA when available (otherwise CPU), and fed a random tensor of shape
    ``(2, 64, 17, 12)``. Gradient tracking is disabled by the decorator.
    """
    if torch.cuda.is_available():
        target = "cuda"
    else:
        target = "cpu"

    net = build_model(weights_path=None, device=target)

    # NOTE(review): the first three sizes were originally named B, T, J --
    # presumably batch, time steps and joints; confirm against the model.
    input_shape = (2, 64, 17, 12)
    sample = torch.randn(*input_shape, device=target)

    result = net(sample)

    # (printed label, key in the model's output mapping)
    report = (
        ("token_feat:", "token_feat"),
        ("s_hat :", "s_hat"),
        ("p_hat :", "p_hat"),
    )
    for label, key in report:
        print(label, tuple(result[key].shape))
|
|
|
|
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|