Diffusion Policy Franka 关节空间动作模型
基于 Diffusion Policy 训练的 Franka 机器人操作策略。模型预测绝对关节角度(8维:7个关节 + 夹爪),可直接部署到真实机器人。
模型概览
| 任务 | 目录 | Epoch | 动作空间 |
|---|---|---|---|
| 抓取牛奶 | checkpoints/pick_up_milk/ |
49, 99 | 8维关节 |
| 叠杯子 | checkpoints/stack_cup/ |
49, 99 | 8维关节 |
| 网球桶翻正 | checkpoints/tennis_bucket_upright/ |
49, 99 | 8维关节 |
模型架构: UNet + ResNet18 Vision Encoder + DDPM
动作空间定义
- 输入状态(agent_pos): 8维
[q1, q2, q3, q4, q5, q6, q7, gripper_joint] - 输入图像: 单个前置摄像头,480×640 RGB(内部 resize 到 224×224)
- 观测步数: 2(当前帧 + 前一帧)
- 输出动作: 15步 × 8维绝对关节角度
- 扩散推理步数: 16(DDIM)
仓库结构
.
├── checkpoints/
│ ├── pick_up_milk/
│ │ ├── checkpoints/
│ │ │ ├── epoch=0049-val_loss=0.000.ckpt # 50 epoch checkpoint (~3.3GB)
│ │ │ └── epoch=0099-val_loss=0.000.ckpt # 100 epoch checkpoint (~3.3GB)
│ │ └── .hydra/
│ │ └── config.yaml # Hydra 配置(加载模型必须)
│ ├── stack_cup/
│ │ ├── checkpoints/
│ │ └── .hydra/
│ └── tennis_bucket_upright/
│ ├── checkpoints/
│ └── .hydra/
├── diffusion_policy/ # DP 推理代码(模型架构、normalizer 等)
├── normalizer_stats.json # 全局归一化参数
├── normalizer_stats_pick_up_milk.json
├── normalizer_stats_stack_cup.json
├── normalizer_stats_tennis_bucket_upright.json
├── gripper_conversion.py # 夹爪关节 → CGS 转换工具
├── example_data/
│ └── pick_up_milk/
│ ├── initial_frame.png # 初始帧图像
│ ├── initial_joints.npy # 初始关节角度(8维)
│ └── instruction.txt # 任务指令
├── scripts/
│ ├── dp_policy_server_franka.py # DP 推理服务器(Socket 通信)
│ └── rollout_with_dp_client_franka.py # IRASim 客户端(闭环 rollout)
└── README.md
快速开始
环境安装
git clone https://huggingface.co/ewykric/dp-franka-joint
cd dp-franka-joint
pip install torch torchvision hydra-core omegaconf dill diffusers
推理服务器
# 终端1:启动 DP 服务器
export CUDA_VISIBLE_DEVICES=0
python scripts/dp_policy_server_franka.py \
--dp_checkpoint checkpoints/pick_up_milk/checkpoints/epoch=0099-val_loss=0.000.ckpt \
--task_name pick_up_milk \
--port 9966
Socket 协议:
| 命令 | 请求 | 响应 |
|---|---|---|
reset_policy |
{instruction, initial_joints(8D)} |
{status} |
get_action |
{image(H,W,3), instruction} |
{action(15,8), terminated} |
update_obs |
{image(H,W,3)} |
{status} |
update_joints |
{joints(8D)} |
{status} |
close |
{} |
{status} |
真机部署
# 真机闭环控制伪代码
dp_server.reset_policy("pick up the milk", initial_joint_positions)
while not done:
image = camera.get_image() # (H, W, 3) RGB
actions, terminated = dp_server.get_action(image, "pick up the milk")
# actions: (15, 8) 绝对关节角度
for joint_cmd in actions:
robot.move_to_joint_positions(joint_cmd[:7])
robot.set_gripper(joint_cmd[7])
time.sleep(1.0 / control_freq)
# 更新 DP 的观测缓冲和关节状态
dp_server.update_obs(camera.get_image())
dp_server.update_joints(robot.get_joint_positions())
与 Pi0 的区别
| Diffusion Policy | Pi0 | |
|---|---|---|
| 架构 | UNet + ResNet18 | PaliGemma 2B + Action Expert 300M |
| 观测帧数 | 2(当前 + 前一帧) | 1(仅当前帧) |
| 语言条件 | 无(通过 normalizer 区分任务) | 有(文本指令输入) |
| 推理方式 | DDPM 扩散去噪(16 步 DDIM) | 流匹配采样 |
| 模型大小 | ~3.3GB | ~12GB |
| 额外通信 | 需要 update_obs + update_joints |
自动更新 state |
训练细节
- 架构: DiffusionUnetImagePolicy + ResNet18 视觉编码器
- 训练数据: 每个任务 100 条真机示教轨迹
- 训练周期: 100 epoch(保存第 49 和 99 epoch)
- Horizon: 15(预测 15 步动作)
- 观测步数: 2(2 帧历史)
- 批大小: 16
- 学习率: 1e-4(AdamW,余弦退火)
- EMA: 启用(推理时使用 EMA 模型)
- 归一化: 基于数据 min/max 的线性归一化到 [-1, 1]
加载模型说明
DP 的加载方式与 Pi0 不同。checkpoint 不包含模型配置,需要配合 .hydra/config.yaml 使用:
import torch, dill, hydra
from omegaconf import OmegaConf
OmegaConf.register_new_resolver("eval", eval, replace=True)
# 1. 加载 hydra config
cfg = OmegaConf.load("checkpoints/pick_up_milk/.hydra/config.yaml")
# 2. 创建模型架构
cls = hydra.utils.get_class(cfg._target_)
workspace = cls(cfg)
# 3. 加载权重
payload = torch.load("checkpoints/pick_up_milk/checkpoints/epoch=0099-val_loss=0.000.ckpt",
pickle_module=dill)
workspace.ema_model.load_state_dict(payload['ema'])
policy = workspace.ema_model
# 4. 设置 normalizer(从 JSON 加载,参考 dp_policy_server_franka.py)
这就是为什么 .hydra/config.yaml 和 diffusion_policy/ 代码都是必须上传的。
许可证
本模型基于 Diffusion Policy 框架,遵循 MIT 许可证。