# dp-franka-joint / scripts / dp_policy_server_franka.py
# Uploaded by ewykric via huggingface_hub (revision 33c751d, verified)
#!/usr/bin/env python3
"""
Diffusion Policy (DP) Server for Franka IRASim Integration
This server runs DP policy inference for Franka robot and communicates with IRASim client via sockets.
Uses the native diffusion_policy interface from fanka_dp_ckpt.
Key features:
- Single arm (8 DOF: 7 joints + 1 gripper)
- Single camera view
- Gripper dimension converted from joint space to CGS (continuous_gripper_state)
"""
import os
import sys
# CRITICAL: Set up path BEFORE any other imports
# This ensures fanka_dp_ckpt's diffusion_policy is used, not RoboTwin's
FANKA_DP_PATH = "/mnt/wyk/IRASim/fanka_dp_ckpt"
sys.path.insert(0, FANKA_DP_PATH)
# Remove any cached diffusion_policy modules to force reimport from fanka_dp_ckpt
# (another diffusion_policy package may already be in sys.modules from a prior import)
modules_to_remove = [key for key in sys.modules.keys() if 'diffusion_policy' in key]
for mod in modules_to_remove:
    del sys.modules[mod]
import socket
import pickle
import argparse
import numpy as np
import cv2
from collections import deque
from pathlib import Path
# DP imports from fanka_dp_ckpt native interface
import torch
import hydra
import dill
from omegaconf import OmegaConf
from diffusion_policy.policy.base_image_policy import BaseImagePolicy
from diffusion_policy.common.pytorch_util import dict_apply
from diffusion_policy.model.common.normalizer import LinearNormalizer, SingleFieldLinearNormalizer
import json
# Import gripper conversion utility
sys.path.insert(0, "/mnt/wyk/IRASim")
from gripper_conversion import joint_to_cgs
# Register eval resolver for OmegaConf so configs may contain ${eval:...} expressions.
# NOTE(review): this evaluates arbitrary Python from config files -- only load trusted configs.
OmegaConf.register_new_resolver("eval", eval, replace=True)
def _resolve_normalizer_stats_path(task_name):
    """Return the path to the pre-computed normalizer-stats JSON.

    Prefers the task-specific file when ``task_name`` is given and the file
    exists; otherwise falls back to the global stats file.
    """
    default_path = "/mnt/wyk/IRASim/fanka_dp_ckpt/normalizer_stats.json"
    if not task_name:
        return default_path
    task_path = f"/mnt/wyk/IRASim/fanka_dp_ckpt/normalizer_stats_{task_name}.json"
    if not os.path.exists(task_path):
        print(f"Warning: Task-specific normalizer not found: {task_path}")
        return default_path
    return task_path


def _build_normalizer(stats):
    """Build a LinearNormalizer from pre-computed min/max statistics.

    Args:
        stats: dict with 'action' and 'agent_pos' entries, each holding
            'min' and 'max' lists (as loaded from the stats JSON).

    Returns:
        A LinearNormalizer with 'action', 'agent_pos' and 'image' fields.
    """
    normalizer = LinearNormalizer()
    for key in ['action', 'agent_pos']:
        input_min = torch.tensor(stats[key]['min'], dtype=torch.float32)
        input_max = torch.tensor(stats[key]['max'], dtype=torch.float32)
        # "limits" mode formula (mirrors the normalizer's _fit function):
        #   scale = (output_max - output_min) / input_range = 2.0 / (max - min)
        #   offset = output_min - scale * input_min = -1.0 - scale * min
        input_range = input_max - input_min
        scale = 2.0 / input_range
        offset = -1.0 - scale * input_min
        # Approximate mean/std for input_stats (std of a uniform distribution).
        input_mean = (input_max + input_min) / 2.0
        input_std = input_range / 4.0  # Approximate std for uniform distribution
        input_stats_dict = {
            'min': input_min,
            'max': input_max,
            'mean': input_mean,
            'std': input_std
        }
        normalizer[key] = SingleFieldLinearNormalizer.create_manual(
            scale=scale,
            offset=offset,
            input_stats_dict=input_stats_dict
        )
    # Image normalizer: images arrive in [0, 1] and are mapped to [-1, 1].
    image_scale = np.array([2], dtype=np.float32)
    image_offset = np.array([-1], dtype=np.float32)
    image_stats = {
        'min': np.array([0], dtype=np.float32),
        'max': np.array([1], dtype=np.float32),
        'mean': np.array([0.5], dtype=np.float32),
        'std': np.array([np.sqrt(1/12)], dtype=np.float32)
    }
    normalizer['image'] = SingleFieldLinearNormalizer.create_manual(
        scale=image_scale,
        offset=image_offset,
        input_stats_dict=image_stats
    )
    return normalizer


def load_dp_policy(checkpoint_path, device, task_name=None):
    """
    Load Diffusion Policy from checkpoint using native interface.

    Checkpoint format (from train_diffusion_unet_franka_real_hybrid_workspace):
    - model: model state dict
    - ema: EMA model state dict
    - optimizer: optimizer state dict
    - lr_scheduler: scheduler state dict

    Args:
        checkpoint_path: path to the .ckpt file; its config is read from the
            sibling ``.hydra/config.yaml`` two directories up.
        device: torch.device to place the policy on.
        task_name: optional task name selecting a task-specific normalizer.

    Returns:
        (policy, cfg) tuple: the ready-to-use policy in eval mode and the
        OmegaConf config it was built from.

    Raises:
        FileNotFoundError: if the hydra config file cannot be found.
    """
    print(f"Loading DP checkpoint from: {checkpoint_path}")
    # Use a context manager so the checkpoint file handle is closed promptly
    # (the previous version leaked an open file object).
    with open(checkpoint_path, 'rb') as f:
        payload = torch.load(f, pickle_module=dill)
    print(f"Checkpoint keys: {list(payload.keys())}")
    # Load config from .hydra/config.yaml (not stored in checkpoint)
    checkpoint_dir = Path(checkpoint_path).parent.parent
    config_path = checkpoint_dir / ".hydra" / "config.yaml"
    if not config_path.exists():
        raise FileNotFoundError(f"Config file not found: {config_path}")
    cfg = OmegaConf.load(config_path)
    print(f"Loaded config from: {config_path}")
    print(f"Config _target_: {cfg._target_}")
    print(f"  n_obs_steps: {cfg.n_obs_steps}")
    print(f"  n_action_steps: {cfg.n_action_steps}")
    print(f"  horizon: {cfg.horizon}")
    # Create workspace using hydra (this creates model architecture)
    cls = hydra.utils.get_class(cfg._target_)
    workspace = cls(cfg)
    print(f"Workspace created: {type(workspace).__name__}")
    # Manually load model weights (not using load_payload which expects different format)
    if 'model' in payload:
        workspace.model.load_state_dict(payload['model'])
        print("Loaded model weights")
    if 'ema' in payload and cfg.training.use_ema:
        workspace.ema_model.load_state_dict(payload['ema'])
        print("Loaded EMA model weights")
    # Get policy (prefer the EMA weights for inference when available)
    policy: BaseImagePolicy
    if 'diffusion' in cfg.name:
        policy = workspace.model
        if cfg.training.use_ema and workspace.ema_model is not None:
            policy = workspace.ema_model
            print("Using EMA model for inference")
    else:
        policy = workspace.model
    # Set up normalizer from pre-computed JSON (fast: no dataset loading needed)
    normalizer_stats_path = _resolve_normalizer_stats_path(task_name)
    print(f"Loading normalizer stats from: {normalizer_stats_path}")
    with open(normalizer_stats_path, 'r') as f:
        stats = json.load(f)
    normalizer = _build_normalizer(stats)
    policy.set_normalizer(normalizer)
    print("Normalizer loaded from JSON (instant!)")
    # Print normalizer stats for debugging.
    # NOTE: use a distinct name so we don't shadow the JSON `stats` dict above.
    print("Normalizer stats:")
    for key in ['action', 'agent_pos']:
        if key in normalizer.params_dict:
            field_stats = normalizer.params_dict[key]['input_stats']
            print(f"  {key}:")
            print(f"    min: {field_stats['min'].detach().numpy()}")
            print(f"    max: {field_stats['max'].detach().numpy()}")
    # Move to device and set eval mode
    policy.eval().to(device)
    # Set inference parameters
    policy.num_inference_steps = 16  # DDIM inference iterations
    # NOTE: Do NOT override n_action_steps! Use value from config (15)
    # policy.n_action_steps is already set from config during model creation
    print(f"DP policy loaded successfully")
    print(f"  - Device: {device}")
    print(f"  - Policy type: {type(policy).__name__}")
    print(f"  - Policy horizon: {policy.horizon}")
    print(f"  - Policy n_action_steps: {policy.n_action_steps}")
    print(f"  - Policy n_obs_steps: {policy.n_obs_steps}")
    print(f"  - Policy action_dim: {policy.action_dim}")
    print(f"  - Policy num_inference_steps: {policy.num_inference_steps}")
    return policy, cfg
class DPPolicyServer:
    """Diffusion Policy server for Franka robot using native interface.

    Serves action predictions over a TCP socket. Wire protocol: each message
    is a 4-byte big-endian length prefix followed by a pickled dict.
    SECURITY NOTE: pickle deserialization can execute arbitrary code, so this
    server must only be exposed to trusted clients on a trusted network.
    """

    def __init__(self, checkpoint_path, port=9966, device='cuda', task_name=None):
        """Load the policy and initialize observation/joint-state buffers.

        Args:
            checkpoint_path: path to the DP checkpoint file.
            port: TCP port to listen on.
            device: torch device string; falls back to CPU if CUDA is unavailable.
            task_name: optional task name selecting a task-specific normalizer.
        """
        self.port = port
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        self.task_name = task_name
        # Load policy and config using native interface
        self.policy, self.cfg = load_dp_policy(checkpoint_path, self.device, task_name=task_name)
        # Get config parameters
        self.n_obs_steps = self.cfg.n_obs_steps
        self.n_action_steps = self.policy.n_action_steps
        self.horizon = self.policy.horizon
        # Get expected image shape from config: [3, H, W]
        self.image_shape = self.cfg.task.shape_meta.obs.image.shape
        self.image_height = self.image_shape[1]
        self.image_width = self.image_shape[2]
        print(f"  - Config n_obs_steps: {self.n_obs_steps}")
        print(f"  - Policy n_action_steps: {self.n_action_steps}")
        print(f"  - Policy horizon: {self.horizon}")
        print(f"  - Expected image shape: {self.image_shape}")
        # Rolling buffer of the most recent n_obs_steps camera frames
        self.obs_buffer = deque(maxlen=self.n_obs_steps)
        # Current joint state, shape (8,): 7 arm joints + 1 gripper
        self.current_joints = None
        print(f"DP server initialized")
        print(f"  - Port: {port}")
        print(f"  - Observation steps: {self.n_obs_steps}")
        print(f"  - Action steps: {self.n_action_steps}")
        print(f"  - Joint dimensions: 8 (7 joints + 1 gripper)")

    def reset(self, instruction, initial_joints):
        """Reset policy state with initial joint configuration.

        Args:
            instruction: task instruction string (logged only; DP ignores it).
            initial_joints: list or (8,) array of initial joint positions.

        Returns:
            A status dict for the client.
        """
        print(f"\n[SERVER] Resetting policy...")
        print(f"  - Instruction: {instruction}")
        print(f"  - Initial joints: {initial_joints}")
        # Convert to numpy
        if isinstance(initial_joints, list):
            initial_joints = np.array(initial_joints, dtype=np.float32)
        self.current_joints = initial_joints
        self.obs_buffer.clear()
        # Reset policy internal state if available
        if hasattr(self.policy, 'reset'):
            self.policy.reset()
        print(f"Policy reset complete")
        return {'status': 'success'}

    def get_action(self, image, instruction):
        """
        Get action from policy given current observation

        Args:
            image: numpy array (H, W, 3) RGB image (can be any size, will be resized)
            instruction: str (not used by DP but kept for API consistency)
        Returns:
            actions: (n_action_steps, 8) numpy array of future joint positions
        """
        print(f"\n[SERVER] Getting action...")
        print(f"  - Image shape: {image.shape}")
        print(f"  - Current joints: {self.current_joints}")
        print(f"  - Obs buffer size: {len(self.obs_buffer)}")
        # Resize image to expected size if needed
        image_resized = self._resize_image(image, verbose=True)
        # Add current image to observation buffer
        self.obs_buffer.append(image_resized)
        # If buffer not full yet (e.g. right after reset), pad with current frame
        while len(self.obs_buffer) < self.n_obs_steps:
            self.obs_buffer.append(image_resized)
        # Prepare observation dictionary for DP
        obs_dict = self._prepare_observation()
        # Run policy inference
        with torch.no_grad():
            # DP expects a leading batch dimension
            obs_dict_input = dict_apply(
                obs_dict,
                lambda x: x.unsqueeze(0).to(self.device) if isinstance(x, torch.Tensor) else x)
            # Get action prediction using native predict_action
            result = self.policy.predict_action(obs_dict_input)
            action = result['action'][0].cpu().numpy()  # (n_action_steps, 8)
        print(f"  - Predicted actions shape: {action.shape}")
        print(f"  - Action range: [{action.min():.4f}, {action.max():.4f}]")
        print(f"  - Action[0]: {action[0]}")
        # Convert gripper dimension from joint space to CGS
        action_with_cgs = self._convert_gripper_to_cgs(action)
        print(f"  - Action (with CGS) first: {action_with_cgs[0]}")
        print(f"  - Action (with CGS) last: {action_with_cgs[-1]}")
        return action_with_cgs

    def update_joints(self, new_joints):
        """Update current joint state (called after executing actions)"""
        if isinstance(new_joints, list):
            new_joints = np.array(new_joints, dtype=np.float32)
        self.current_joints = new_joints
        print(f"[SERVER] Updated joints: {self.current_joints}")

    def update_obs(self, image):
        """Update observation buffer with intermediate frame (for temporal continuity)"""
        self.obs_buffer.append(self._resize_image(image))
        print(f"[SERVER] Updated obs buffer, size: {len(self.obs_buffer)}")

    def _resize_image(self, image, verbose=False):
        """Resize an (H, W, 3) frame to the policy's expected input size.

        Returns the input unchanged when it already matches; pass verbose=True
        to log what happened (used by get_action).
        """
        target_h, target_w = self.image_height, self.image_width
        if image.shape[:2] == (target_h, target_w):
            if verbose:
                print(f"  - Image already correct size: {image.shape}")
            return image
        image_resized = cv2.resize(image, (target_w, target_h))  # cv2.resize takes (W, H)
        if verbose:
            print(f"  - Resized from {image.shape} to {image_resized.shape}")
        return image_resized

    @staticmethod
    def _deserialize_image(image_data):
        """Rebuild a numpy image from its serialized {'data', 'dtype', 'shape'} dict."""
        return np.frombuffer(
            image_data['data'],
            dtype=image_data['dtype']
        ).reshape(image_data['shape'])

    def _prepare_observation(self):
        """Prepare observation dictionary for DP input.

        Returns:
            dict with 'image' (T, 3, H, W) float tensor in [0, 1] and
            'agent_pos' (T, 8) float tensor, where T = n_obs_steps.
        """
        # Stack images: (n_obs_steps, H, W, 3)
        images = np.stack(list(self.obs_buffer), axis=0)
        # Convert to torch tensor and normalize to [0, 1];
        # DP expects (n_obs_steps, 3, H, W) in range [0, 1]
        images_torch = torch.from_numpy(images).float() / 255.0  # (T, H, W, 3)
        images_torch = images_torch.permute(0, 3, 1, 2)  # (T, 3, H, W)
        # Prepare agent_pos (joint state):
        # repeat the current joint reading for all observation steps
        agent_pos = torch.from_numpy(self.current_joints).float()  # (8,)
        agent_pos = agent_pos.unsqueeze(0).repeat(self.n_obs_steps, 1)  # (T, 8)
        obs_dict = {
            'image': images_torch,   # (T, 3, H, W)
            'agent_pos': agent_pos,  # (T, 8)
        }
        return obs_dict

    def _convert_gripper_to_cgs(self, actions):
        """
        Convert gripper dimension from joint space to CGS

        Args:
            actions: (n_action_steps, 8) numpy array
                Last dimension is gripper in joint space
        Returns:
            actions_cgs: (n_action_steps, 8) numpy array
                Last dimension converted to CGS
        """
        actions_cgs = actions.copy()
        # Convert last dimension (gripper)
        gripper_joints = actions[:, 7]  # (n_action_steps,)
        gripper_cgs = joint_to_cgs(gripper_joints, method='linear_fit')
        actions_cgs[:, 7] = gripper_cgs
        print(f"[SERVER] Gripper conversion:")
        print(f"  - Joint range: [{gripper_joints.min():.4f}, {gripper_joints.max():.4f}]")
        print(f"  - CGS range: [{gripper_cgs.min():.4f}, {gripper_cgs.max():.4f}]")
        return actions_cgs

    def run(self):
        """Run the blocking server loop.

        Accepts a single client connection, then dispatches commands
        ('reset_policy', 'get_action', 'update_joints', 'update_obs', 'close')
        until the client disconnects or sends 'close'.
        """
        server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        server_socket.bind(('0.0.0.0', self.port))
        server_socket.listen(1)
        print(f"\n{'='*70}")
        print(f"DP Server listening on port {self.port}")
        print(f"{'='*70}")
        print("Waiting for client connection...")
        client_socket, client_address = server_socket.accept()
        print(f"Client connected from {client_address}")
        try:
            while True:
                # Receive request; None means disconnect or receive error
                request = self._receive_request(client_socket)
                if request is None:
                    break
                command = request['command']
                print(f"\n[SERVER] Received command: {command}")
                # Process command
                if command == 'reset_policy':
                    response = self.reset(
                        request['instruction'],
                        request['initial_joints']
                    )
                elif command == 'get_action':
                    image = self._deserialize_image(request['image'])
                    action = self.get_action(image, request.get('instruction', ''))
                    response = {
                        'action': action.tolist(),
                        'terminated': False
                    }
                elif command == 'update_joints':
                    self.update_joints(request['joints'])
                    response = {'status': 'success'}
                elif command == 'update_obs':
                    image = self._deserialize_image(request['image'])
                    self.update_obs(image)
                    response = {'status': 'success'}
                elif command == 'close':
                    print("[SERVER] Received close command")
                    response = {'status': 'closing'}
                    self._send_response(client_socket, response)
                    break
                else:
                    response = {'error': f'Unknown command: {command}'}
                # Send response
                self._send_response(client_socket, response)
        except KeyboardInterrupt:
            print("\n\nServer interrupted by user")
        except Exception as e:
            print(f"\n\nError: {e}")
            import traceback
            traceback.print_exc()
        finally:
            client_socket.close()
            server_socket.close()
            print("Server closed")

    def _receive_request(self, client_socket):
        """Receive one length-prefixed pickled request.

        Returns the deserialized request dict, or None on clean disconnect
        or any receive error.
        """
        try:
            # Receive 4-byte big-endian size prefix.
            size_data = self._recv_exactly(client_socket, 4)
            # BUGFIX: a partial read (1-3 bytes) previously passed the
            # `if not size_data` check and produced a garbage size; treat any
            # short prefix as a disconnect.
            if len(size_data) < 4:
                return None
            data_size = int.from_bytes(size_data, byteorder='big')
            # Receive payload
            data = self._recv_exactly(client_socket, data_size)
            if len(data) != data_size:
                raise ConnectionError(f"Expected {data_size} bytes, got {len(data)}")
            # SECURITY: pickle.loads can execute arbitrary code from the wire;
            # only accept connections from trusted clients.
            request = pickle.loads(data)
            return request
        except Exception as e:
            print(f"Error receiving request: {e}")
            return None

    def _send_response(self, client_socket, response):
        """Send a response dict as a 4-byte length prefix + pickled payload."""
        data = pickle.dumps(response)
        size = len(data)
        client_socket.sendall(size.to_bytes(4, byteorder='big'))
        client_socket.sendall(data)

    def _recv_exactly(self, sock, num_bytes):
        """Receive exactly num_bytes from socket.

        May return fewer bytes if the peer closes the connection; callers
        must check the returned length.
        """
        data = b''
        while len(data) < num_bytes:
            packet = sock.recv(num_bytes - len(data))
            if not packet:
                break  # peer closed connection
            data += packet
        return data
def main():
    """Entry point: parse CLI options, print a startup banner, and launch
    the Franka DP policy server."""
    arg_parser = argparse.ArgumentParser(
        description='Franka DP Policy Server (Native Interface)')
    arg_parser.add_argument('--dp_checkpoint', type=str, required=True,
                            help='Path to DP checkpoint')
    arg_parser.add_argument('--port', type=int, default=9966,
                            help='Server port')
    arg_parser.add_argument('--device', type=str, default='cuda',
                            help='Device for inference')
    arg_parser.add_argument('--task_name', type=str, default=None,
                            help='Task name for task-specific normalizer (e.g., tennis_bucket_upright)')
    opts = arg_parser.parse_args()

    banner = "=" * 70
    print(banner)
    print("Franka Diffusion Policy Server (Native Interface)")
    print(banner)
    print(f"Checkpoint: {opts.dp_checkpoint}")
    print(f"Task: {opts.task_name or 'all_tasks (global normalizer)'}")
    print(f"Port: {opts.port}")
    print(f"Device: {opts.device}")
    print(banner)

    # Build the server and block in its request loop until shutdown.
    DPPolicyServer(
        checkpoint_path=opts.dp_checkpoint,
        port=opts.port,
        device=opts.device,
        task_name=opts.task_name
    ).run()


if __name__ == '__main__':
    main()