| """ |
| Pi0 RTC Cube 模型推理示例(基于 realtime-vla Triton 加速) |
| |
| 在 RTX 4090 上运行,使用 2 个相机视角(camera1 主视角 + camera2 辅助视角)。 |
| 模型输出 50 步 action chunk,每步 32 维(前 6 维有效:5 关节 + 1 夹爪)。 |
| |
| 依赖: |
| pip install torch triton numpy pillow |
| |
| 用法: |
| python infer_example.py |
| """ |
|
|
import json
import pickle
import time

import numpy as np
import torch

from pi0_infer import Pi0Inference
|
|
| |
# Pickled model checkpoint that main() loads and hands to Pi0Inference.
CHECKPOINT_PATH = "converted_checkpoint.pkl"
# JSON file holding the mean/std normalization statistics.
NORM_STATS_PATH = "norm_stats.json"
# Number of camera views passed to the inference engine (camera1 + camera2).
NUM_VIEWS = 2
# Number of steps in one predicted action chunk; also the row count of the
# diffusion noise tensor.
CHUNK_SIZE = 50
# Number of leading action columns kept after unnormalization.
ACTION_DIM = 6
# Length of the raw joint-state vector before it is padded for the model.
STATE_DIM = 6
|
|
|
|
def load_norm_stats(path):
    """Load normalization statistics from a JSON file.

    Args:
        path: Path to the JSON file.

    Returns:
        The value stored under the top-level ``"norm_stats"`` key when that
        key exists, otherwise the whole decoded JSON object.
    """
    # JSON is UTF-8 by specification; do not rely on the platform default
    # encoding, which differs between operating systems.
    with open(path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    return data.get('norm_stats', data)
|
|
|
|
def normalize_state(state, norm_stats, target_dim=32):
    """Normalize a state vector with mean/std and zero-pad it to ``target_dim``.

    Args:
        state: 1-D state vector (array-like).
        norm_stats: Mapping with ``norm_stats["state"]["mean"]`` and
            ``norm_stats["state"]["std"]`` sequences.
        target_dim: Length of the returned vector.

    Returns:
        A float32 array of length ``target_dim`` whose leading entries are
        ``(state - mean) / (std + 1e-6)`` and whose remaining entries are 0.

    Raises:
        ValueError: If the normalized state is not 1-D or is longer than
            ``target_dim``.
    """
    state_mean = np.array(norm_stats["state"]["mean"])
    state_std = np.array(norm_stats["state"]["std"])
    # The small epsilon keeps the division finite when a std entry is zero.
    normalized = (np.asarray(state) - state_mean) / (state_std + 1e-6)
    if normalized.ndim != 1 or normalized.shape[0] > target_dim:
        raise ValueError(
            f"normalized state has shape {normalized.shape}; expected a 1-D "
            f"vector with at most {target_dim} entries"
        )
    padded = np.zeros(target_dim, dtype=np.float32)
    padded[:len(normalized)] = normalized
    return padded
|
|
|
|
def unnormalize_actions(actions, norm_stats, action_dim=6):
    """Undo mean/std normalization on model actions and keep the valid columns.

    Args:
        actions: Array whose last axis is the model action width
            (``(steps, width)`` in this script).
        norm_stats: Mapping with ``norm_stats["actions"]["mean"]`` and
            ``norm_stats["actions"]["std"]`` sequences.
        action_dim: Number of leading action columns to return.

    Returns:
        ``actions * (std + 1e-6) + mean`` restricted to the first
        ``action_dim`` columns. Columns with no statistics use mean 0 and
        std 0.
    """
    actions = np.asarray(actions)
    # Take the width from the input instead of assuming a fixed 32 so other
    # model widths work too.
    width = actions.shape[-1]
    actions_mean = np.array(norm_stats["actions"]["mean"])
    actions_std = np.array(norm_stats["actions"]["std"])

    padded_mean = np.zeros(width, dtype=np.float32)
    padded_std = np.zeros(width, dtype=np.float32)
    # Statistics beyond the action width cannot map to any column; ignore them.
    n_mean = min(len(actions_mean), width)
    n_std = min(len(actions_std), width)
    padded_mean[:n_mean] = actions_mean[:n_mean]
    padded_std[:n_std] = actions_std[:n_std]

    raw = actions * (padded_std + 1e-6) + padded_mean
    # Ellipsis keeps any leading batch/step axes intact.
    return raw[..., :action_dim]
|
|
|
|
def normalize_image(image):
    """Letterbox an image to 224x224 and scale pixel values to [-1, 1].

    Images that are not already 224x224 are resized with their aspect ratio
    preserved (bilinear) and centered on a zero-filled 224x224 canvas, so the
    padding becomes -1.0 after scaling.

    Args:
        image: ``(H, W, C)`` image array; expected to be uint8 (0-255).

    Returns:
        A float32 array with values ``pixel / 255 * 2 - 1``.
    """
    if image.shape[:2] != (224, 224):
        # Imported lazily so Pillow is only required when a resize is needed.
        from PIL import Image as PILImage

        pil_img = PILImage.fromarray(image)
        cur_w, cur_h = pil_img.size
        # Scale by the larger side so the whole image fits inside 224x224.
        ratio = max(cur_w / 224, cur_h / 224)
        new_h, new_w = int(cur_h / ratio), int(cur_w / ratio)
        resized = pil_img.resize((new_w, new_h), resample=PILImage.BILINEAR)
        canvas = PILImage.new(resized.mode, (224, 224), 0)
        # Center the resized image on the canvas.
        pad_h = max(0, (224 - new_h) // 2)
        pad_w = max(0, (224 - new_w) // 2)
        canvas.paste(resized, (pad_w, pad_h))
        image = np.array(canvas)
    return image.astype(np.float32) / 255.0 * 2.0 - 1.0
|
|
|
|
def main():
    """Run one inference pass on dummy inputs, then a short latency benchmark.

    Loads the pickled checkpoint and the normalization statistics, builds two
    random camera frames and a zero joint state, calls
    ``Pi0Inference.forward`` once, prints the unnormalized action chunk, and
    finally times 100 further forward calls. All inputs are moved to the GPU,
    so a CUDA device is required.
    """
    # 1. Load the weights and build the inference engine.
    print("加载模型权重 ...")
    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file; only load checkpoints that come from a trusted source.
    with open(CHECKPOINT_PATH, 'rb') as f:
        checkpoint = pickle.load(f)
    print("初始化推理引擎 ...")
    infer = Pi0Inference(checkpoint, num_views=NUM_VIEWS, chunk_size=CHUNK_SIZE)
    print("模型就绪!")

    # 2. Load the normalization statistics.
    norm_stats = load_norm_stats(NORM_STATS_PATH)

    # 3. Build dummy inputs; replace these with real camera frames and joint
    #    readings. randint's upper bound is exclusive, so 256 is needed to
    #    cover the full uint8 range.
    camera1_image = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)
    camera2_image = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)
    joint_state = np.zeros(STATE_DIM, dtype=np.float32)

    # 4. Preprocess: letterbox + scale the images, normalize + pad the state.
    img1 = normalize_image(camera1_image)
    img2 = normalize_image(camera2_image)

    # Stack the views along a new leading axis.
    observation_images = torch.from_numpy(
        np.stack([img1, img2], axis=0)
    ).to(torch.bfloat16).cuda()

    observation_state = torch.from_numpy(
        normalize_state(joint_state, norm_stats, target_dim=32)
    ).to(torch.bfloat16).cuda()

    # One 32-dim noise row per step of the action chunk.
    diffusion_noise = torch.randn(
        CHUNK_SIZE, 32, dtype=torch.bfloat16, device="cuda"
    )

    # 5. Run inference.
    print("执行推理 ...")
    with torch.no_grad():
        raw_actions = infer.forward(observation_images, observation_state, diffusion_noise)

    # 6. Post-process: move to CPU as float32 and undo the normalization.
    raw_actions_np = raw_actions.cpu().float().numpy()
    actions = unnormalize_actions(raw_actions_np, norm_stats, ACTION_DIM)

    print(f"\n推理结果:")
    print(f" Action chunk shape: {actions.shape}")
    print(f" 第一步 action: {actions[0]}")
    print(f" Action 范围: [{actions.min():.4f}, {actions.max():.4f}]")

    # 7. Latency benchmark. Synchronize before and after so that only the
    #    timed forward calls fall inside the measured interval.
    print("\n性能测试 ...")
    n_iters = 100
    torch.cuda.synchronize()
    start = time.perf_counter()
    # Match the inference call above: no autograd bookkeeping while timing.
    with torch.no_grad():
        for _ in range(n_iters):
            infer.forward(observation_images, observation_state, diffusion_noise)
    torch.cuda.synchronize()
    elapsed = (time.perf_counter() - start) / n_iters * 1000
    print(f" 平均推理延迟: {elapsed:.1f} ms ({1000/elapsed:.0f} FPS)")
|
|
|
|
# Run the example only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|