"""Smoke-test a traced TorchScript model with a single forward pass on random input."""
from typing import Sequence, Union

import torch

# Default TorchScript archive to load.
MODEL_PATH = "model_logits_traced.pt"
# Dummy input shape (B, C, H, W). NOTE(review): presumably this must match the
# channel count the model was traced with — confirm against the tracing script.
INPUT_SHAPE = (1, 32, 512, 512)


def run_smoke_test(
    model_path: str = MODEL_PATH,
    input_shape: Sequence[int] = INPUT_SHAPE,
    device: Union[str, torch.device] = "cpu",
) -> torch.Tensor:
    """Load a TorchScript model, run it once on random input, and print a summary.

    Args:
        model_path: Path to the serialized TorchScript archive.
        input_shape: Shape of the random normal dummy input tensor.
        device: Device the model is loaded onto; the input is created on the
            same device so the forward pass cannot hit a device mismatch.

    Returns:
        The model's output tensor (computed without autograd tracking).
    """
    # Without map_location, the weights are restored onto the device they were
    # saved from (e.g. CUDA) while the dummy input would still live on the CPU.
    model = torch.jit.load(model_path, map_location=device)
    model.eval()

    x = torch.randn(*input_shape, device=device)
    with torch.no_grad():
        y = model(x)

    print("Inference successful!")
    print("Output shape:", y.shape)
    print("Min/Max:", y.min().item(), y.max().item())
    return y


if __name__ == "__main__":
    run_smoke_test()