| | import torch |
| | import argparse |
| | from pi3.utils.basic import load_images_as_tensor, write_ply |
| | from pi3.utils.geometry import depth_edge |
| | from pi3.models.pi3 import Pi3 |
| |
|
# A point is kept only when the sigmoid of its confidence logit exceeds this.
CONF_THRESHOLD = 0.1
# Relative tolerance forwarded to ``depth_edge`` when flagging depth edges.
DEPTH_EDGE_RTOL = 0.03
# Identifier handed to ``Pi3.from_pretrained`` when no --ckpt is given
# (presumably a Hugging Face Hub repo id -- confirm against Pi3).
DEFAULT_PRETRAINED = "yyfz233/Pi3"


def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options and resolve the sampling interval.

    Args:
        argv: Argument strings to parse. ``None`` (the default) lets argparse
            read ``sys.argv[1:]``.

    Returns:
        The parsed namespace. A negative ``interval`` (the default is -1) is
        replaced by 10 when ``data_path`` ends in ``.mp4`` (case-insensitive)
        and by 1 otherwise; any other value is kept as given.
    """
    parser = argparse.ArgumentParser(description="Run inference with the Pi3 model.")

    parser.add_argument("--data_path", type=str, default='examples/parkour',
                        help="Path to the input image directory or a video file.")
    parser.add_argument("--save_path", type=str, default='examples/parkour.ply',
                        help="Path to save the output .ply file.")
    parser.add_argument("--interval", type=int, default=-1,
                        help="Interval to sample image. Default: 1 for images dir, 10 for video")
    parser.add_argument("--ckpt", type=str, default=None,
                        help="Path to the model checkpoint file. Default: None")
    parser.add_argument("--device", type=str, default='cuda',
                        help="Device to run inference on ('cuda' or 'cpu'). Default: 'cuda'")

    args = parser.parse_args(argv)
    if args.interval < 0:
        # Compare case-insensitively so that e.g. "clip.MP4" also gets the
        # video default instead of silently falling back to 1.
        is_video = args.data_path.lower().endswith('.mp4')
        args.interval = 10 if is_video else 1
    print(f'Sampling interval: {args.interval}')
    return args


def load_model(ckpt, device):
    """Build a Pi3 model on ``device`` and switch it to eval mode.

    Args:
        ckpt: Path to a checkpoint file, or ``None`` to fetch the weights via
            ``Pi3.from_pretrained(DEFAULT_PRETRAINED)``.
        device: ``torch.device`` the model is moved to.

    Returns:
        The model, in eval mode, on ``device``.
    """
    if ckpt is None:
        return Pi3.from_pretrained(DEFAULT_PRETRAINED).to(device).eval()

    model = Pi3().to(device).eval()
    if ckpt.endswith('.safetensors'):
        # Imported lazily so safetensors is only required for this format.
        from safetensors.torch import load_file
        weight = load_file(ckpt)
    else:
        # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
        # Python objects, which can execute code. Only load checkpoints from a
        # trusted source; switch to weights_only=True if the file is a plain
        # state dict.
        weight = torch.load(ckpt, map_location=device, weights_only=False)

    model.load_state_dict(weight)
    return model


def inference_autocast(device: torch.device):
    """Return the mixed-precision context manager to use for inference.

    Args:
        device: Device the model and inputs live on.

    Returns:
        For CUDA devices, a ``torch.amp.autocast`` context using bfloat16 when
        the GPU's compute capability major version is at least 8 and float16
        otherwise. For every other device type, a no-op context, so inference
        runs without autocast.
    """
    if device.type != 'cuda':
        # A CUDA autocast region does not apply to non-CUDA ops, and requesting
        # it without CUDA only produces a warning, so skip it entirely.
        from contextlib import nullcontext
        return nullcontext()

    major, _minor = torch.cuda.get_device_capability(device)
    # bfloat16 needs hardware support; fall back to float16 on older GPUs.
    dtype = torch.bfloat16 if major >= 8 else torch.float16
    return torch.amp.autocast('cuda', dtype=dtype)


def main(argv=None) -> None:
    """Run Pi3 on the input frames and save the filtered point cloud as .ply.

    Args:
        argv: Optional argument list forwarded to ``parse_args``.
    """
    args = parse_args(argv)

    # 1. Model.
    print("Loading model...")
    device = torch.device(args.device)
    model = load_model(args.ckpt, device)

    # 2. Input frames, sampled every ``args.interval`` items.
    imgs = load_images_as_tensor(args.data_path, interval=args.interval).to(device)

    # 3. Inference; ``imgs[None]`` adds a leading batch dimension of size 1.
    print("Running model inference...")
    with torch.no_grad(), inference_autocast(device):
        res = model(imgs[None])

    # 4. Keep points that are confident enough and not on a depth edge.
    #    The trailing ``[0]`` drops the batch dimension added above.
    masks = torch.sigmoid(res['conf'][..., 0]) > CONF_THRESHOLD
    non_edge = ~depth_edge(res['local_points'][..., 2], rtol=DEPTH_EDGE_RTOL)
    masks = torch.logical_and(masks, non_edge)[0]

    # 5. Export. The permute moves axis 1 of ``imgs`` last so the mask can
    #    index it like the points (assumes imgs is (N, C, H, W) -- TODO confirm).
    print(f"Saving point cloud to: {args.save_path}")
    points = res['points'][0][masks].cpu()
    colors = imgs.permute(0, 2, 3, 1)[masks].cpu()
    write_ply(points, colors, args.save_path)
    print("Done.")


if __name__ == '__main__':
    main()