Diffusion Policy Franka 关节空间动作模型

基于 Diffusion Policy 训练的 Franka 机器人操作策略。模型预测绝对关节角度(8维:7个关节 + 夹爪),可直接部署到真实机器人。

模型概览

任务 目录 Epoch 动作空间
抓取牛奶 checkpoints/pick_up_milk/ 49, 99 8维关节
叠杯子 checkpoints/stack_cup/ 49, 99 8维关节
网球桶翻正 checkpoints/tennis_bucket_upright/ 49, 99 8维关节

模型架构: UNet + ResNet18 Vision Encoder + DDPM

动作空间定义

  • 输入状态(agent_pos): 8维 [q1, q2, q3, q4, q5, q6, q7, gripper_joint]
  • 输入图像: 单个前置摄像头,480×640 RGB(内部 resize 到 224×224)
  • 观测步数: 2(当前帧 + 前一帧)
  • 输出动作: 15步 × 8维绝对关节角度
  • 扩散推理步数: 16(DDIM)

仓库结构

.
├── checkpoints/
│   ├── pick_up_milk/
│   │   ├── checkpoints/
│   │   │   ├── epoch=0049-val_loss=0.000.ckpt  # 50 epoch checkpoint (~3.3GB)
│   │   │   └── epoch=0099-val_loss=0.000.ckpt  # 100 epoch checkpoint (~3.3GB)
│   │   └── .hydra/
│   │       └── config.yaml                      # Hydra 配置(加载模型必须)
│   ├── stack_cup/
│   │   ├── checkpoints/
│   │   └── .hydra/
│   └── tennis_bucket_upright/
│       ├── checkpoints/
│       └── .hydra/
├── diffusion_policy/                # DP 推理代码(模型架构、normalizer 等)
├── normalizer_stats.json            # 全局归一化参数
├── normalizer_stats_pick_up_milk.json
├── normalizer_stats_stack_cup.json
├── normalizer_stats_tennis_bucket_upright.json
├── gripper_conversion.py            # 夹爪关节 → CGS 转换工具
├── example_data/
│   └── pick_up_milk/
│       ├── initial_frame.png        # 初始帧图像
│       ├── initial_joints.npy       # 初始关节角度(8维)
│       └── instruction.txt          # 任务指令
├── scripts/
│   ├── dp_policy_server_franka.py   # DP 推理服务器(Socket 通信)
│   └── rollout_with_dp_client_franka.py  # IRASim 客户端(闭环 rollout)
└── README.md

快速开始

环境安装

git clone https://huggingface.co/ewykric/dp-franka-joint
cd dp-franka-joint
pip install torch torchvision hydra-core omegaconf dill diffusers

推理服务器

# 终端1:启动 DP 服务器
export CUDA_VISIBLE_DEVICES=0
python scripts/dp_policy_server_franka.py \
    --dp_checkpoint checkpoints/pick_up_milk/checkpoints/epoch=0099-val_loss=0.000.ckpt \
    --task_name pick_up_milk \
    --port 9966

Socket 协议:

命令 请求 响应
reset_policy {instruction, initial_joints(8D)} {status}
get_action {image(H,W,3), instruction} {action(15,8), terminated}
update_obs {image(H,W,3)} {status}
update_joints {joints(8D)} {status}
close {} {status}

真机部署

# 真机闭环控制伪代码
dp_server.reset_policy("pick up the milk", initial_joint_positions)

while not done:
    image = camera.get_image()  # (H, W, 3) RGB
    actions, terminated = dp_server.get_action(image, "pick up the milk")
    # actions: (15, 8) 绝对关节角度

    for joint_cmd in actions:
        robot.move_to_joint_positions(joint_cmd[:7])
        robot.set_gripper(joint_cmd[7])
        time.sleep(1.0 / control_freq)

    # 更新 DP 的观测缓冲和关节状态
    dp_server.update_obs(camera.get_image())
    dp_server.update_joints(robot.get_joint_positions())

与 Pi0 的区别

Diffusion Policy Pi0
架构 UNet + ResNet18 PaliGemma 2B + Action Expert 300M
观测帧数 2(当前 + 前一帧) 1(仅当前帧)
语言条件 无(通过 normalizer 区分任务) 有(文本指令输入)
推理方式 DDPM 扩散去噪(16 步 DDIM) 流匹配采样
模型大小 ~3.3GB ~12GB
额外通信 需要 update_obs + update_joints 自动更新 state

训练细节

  • 架构: DiffusionUnetImagePolicy + ResNet18 视觉编码器
  • 训练数据: 每个任务 100 条真机示教轨迹
  • 训练周期: 100 epoch(保存第 49 和 99 epoch)
  • Horizon: 15(预测 15 步动作)
  • 观测步数: 2(2 帧历史)
  • 批大小: 16
  • 学习率: 1e-4(AdamW,余弦退火)
  • EMA: 启用(推理时使用 EMA 模型)
  • 归一化: 基于数据 min/max 的线性归一化到 [-1, 1]

加载模型说明

DP 的加载方式与 Pi0 不同。checkpoint 不包含模型配置,需要配合 .hydra/config.yaml 使用:

import torch, dill, hydra
from omegaconf import OmegaConf

OmegaConf.register_new_resolver("eval", eval, replace=True)

# 1. 加载 hydra config
cfg = OmegaConf.load("checkpoints/pick_up_milk/.hydra/config.yaml")

# 2. 创建模型架构
cls = hydra.utils.get_class(cfg._target_)
workspace = cls(cfg)

# 3. 加载权重
payload = torch.load("checkpoints/pick_up_milk/checkpoints/epoch=0099-val_loss=0.000.ckpt",
                      pickle_module=dill)
workspace.ema_model.load_state_dict(payload['ema'])
policy = workspace.ema_model

# 4. 设置 normalizer(从 JSON 加载,参考 dp_policy_server_franka.py)

这就是为什么 .hydra/config.yamldiffusion_policy/ 代码都是必须上传的。

许可证

本模型基于 Diffusion Policy 框架,遵循 MIT 许可证。

Downloads last month

-

Downloads are not tracked for this model. How to track
Video Preview
loading