| |
| """ |
| Diffusion Policy (DP) Server for Franka IRASim Integration |
| |
| This server runs DP policy inference for Franka robot and communicates with IRASim client via sockets. |
| Uses the native diffusion_policy interface from fanka_dp_ckpt. |
| |
| Key features: |
| - Single arm (8 DOF: 7 joints + 1 gripper) |
| - Single camera view |
| - Gripper dimension converted from joint space to CGS (continuous_gripper_state) |
| """ |
|
|
import os
import sys


# Make the bundled diffusion_policy package (inside fanka_dp_ckpt) take
# precedence over any other copy already on sys.path.
FANKA_DP_PATH = "/mnt/wyk/IRASim/fanka_dp_ckpt"
sys.path.insert(0, FANKA_DP_PATH)

# Purge any already-imported diffusion_policy modules so the imports below
# resolve against FANKA_DP_PATH instead of a previously cached copy.
modules_to_remove = [key for key in sys.modules.keys() if 'diffusion_policy' in key]
for mod in modules_to_remove:
    del sys.modules[mod]

import socket
import pickle
import argparse
import numpy as np
import cv2
from collections import deque
from pathlib import Path

import torch
import hydra
import dill
from omegaconf import OmegaConf
from diffusion_policy.policy.base_image_policy import BaseImagePolicy
from diffusion_policy.common.pytorch_util import dict_apply
from diffusion_policy.model.common.normalizer import LinearNormalizer, SingleFieldLinearNormalizer
import json

# joint_to_cgs (gripper joint-space -> continuous_gripper_state) lives in the
# IRASim repo root, not in the fanka_dp_ckpt package.
sys.path.insert(0, "/mnt/wyk/IRASim")
from gripper_conversion import joint_to_cgs

# The checkpoint's Hydra config uses ${eval:...} resolvers.
# NOTE(review): registering `eval` executes arbitrary expressions from the
# config — only load trusted checkpoints/configs with this server.
OmegaConf.register_new_resolver("eval", eval, replace=True)
|
|
|
|
def load_dp_policy(checkpoint_path, device, task_name=None):
    """
    Load a Diffusion Policy from checkpoint using the native interface.

    Args:
        checkpoint_path: Path to a workspace checkpoint. The Hydra config is
            expected two levels up, at <run_dir>/.hydra/config.yaml.
        device: torch device (or device string) for inference.
        task_name: Optional task name selecting a task-specific normalizer
            stats JSON; falls back to the global stats file when missing.

    Returns:
        (policy, cfg): the eval-mode policy (EMA weights when training used
        EMA) moved to `device`, and the loaded OmegaConf config.

    Raises:
        FileNotFoundError: if the Hydra config next to the checkpoint is missing.

    Checkpoint format (from train_diffusion_unet_franka_real_hybrid_workspace):
    - model: model state dict
    - ema: EMA model state dict
    - optimizer: optimizer state dict
    - lr_scheduler: scheduler state dict
    """
    print(f"Loading DP checkpoint from: {checkpoint_path}")

    # Use a context manager so the file handle is always closed (the original
    # passed an open() result straight into torch.load and leaked it).
    with open(checkpoint_path, 'rb') as f:
        payload = torch.load(f, pickle_module=dill)
    print(f"Checkpoint keys: {list(payload.keys())}")

    # Hydra run layout: <run_dir>/checkpoints/<ckpt> with the config at
    # <run_dir>/.hydra/config.yaml.
    checkpoint_dir = Path(checkpoint_path).parent.parent
    config_path = checkpoint_dir / ".hydra" / "config.yaml"
    if not config_path.exists():
        raise FileNotFoundError(f"Config file not found: {config_path}")
    cfg = OmegaConf.load(config_path)
    print(f"Loaded config from: {config_path}")

    print(f"Config _target_: {cfg._target_}")
    print(f" n_obs_steps: {cfg.n_obs_steps}")
    print(f" n_action_steps: {cfg.n_action_steps}")
    print(f" horizon: {cfg.horizon}")

    # Instantiate the training workspace class named by the config; it owns
    # both the raw model and (optionally) its EMA copy.
    cls = hydra.utils.get_class(cfg._target_)
    workspace = cls(cfg)
    print(f"Workspace created: {type(workspace).__name__}")

    if 'model' in payload:
        workspace.model.load_state_dict(payload['model'])
        print("Loaded model weights")

    if 'ema' in payload and cfg.training.use_ema:
        workspace.ema_model.load_state_dict(payload['ema'])
        print("Loaded EMA model weights")

    # Prefer EMA weights for diffusion policies when training used EMA.
    policy: BaseImagePolicy
    if 'diffusion' in cfg.name:
        policy = workspace.model
        if cfg.training.use_ema and workspace.ema_model is not None:
            policy = workspace.ema_model
            print("Using EMA model for inference")
    else:
        policy = workspace.model

    # Pick a task-specific normalizer stats file when available, otherwise
    # the global one shared across tasks.
    if task_name:
        normalizer_stats_path = f"/mnt/wyk/IRASim/fanka_dp_ckpt/normalizer_stats_{task_name}.json"
        if not os.path.exists(normalizer_stats_path):
            print(f"Warning: Task-specific normalizer not found: {normalizer_stats_path}")
            normalizer_stats_path = "/mnt/wyk/IRASim/fanka_dp_ckpt/normalizer_stats.json"
    else:
        normalizer_stats_path = "/mnt/wyk/IRASim/fanka_dp_ckpt/normalizer_stats.json"
    print(f"Loading normalizer stats from: {normalizer_stats_path}")

    with open(normalizer_stats_path, 'r') as f:
        stats = json.load(f)

    # Rebuild the LinearNormalizer from the JSON min/max stats so startup
    # does not need to re-fit against the dataset.
    normalizer = LinearNormalizer()

    for key in ['action', 'agent_pos']:
        input_min = torch.tensor(stats[key]['min'], dtype=torch.float32)
        input_max = torch.tensor(stats[key]['max'], dtype=torch.float32)

        # Affine map [min, max] -> [-1, 1]:  y = scale * x + offset.
        # NOTE(review): a constant dimension (max == min) yields an infinite
        # scale; the stats files are assumed to have non-degenerate ranges
        # for every dimension — verify against the dataset exporter.
        input_range = input_max - input_min
        scale = 2.0 / input_range
        offset = -1.0 - scale * input_min

        # Bookkeeping stats stored alongside scale/offset.
        input_mean = (input_max + input_min) / 2.0
        input_std = input_range / 4.0

        input_stats_dict = {
            'min': input_min,
            'max': input_max,
            'mean': input_mean,
            'std': input_std
        }

        field_normalizer = SingleFieldLinearNormalizer.create_manual(
            scale=scale,
            offset=offset,
            input_stats_dict=input_stats_dict
        )
        normalizer[key] = field_normalizer

    # Images: fixed map [0, 1] -> [-1, 1]; std is that of a uniform [0, 1].
    image_scale = np.array([2], dtype=np.float32)
    image_offset = np.array([-1], dtype=np.float32)
    image_stats = {
        'min': np.array([0], dtype=np.float32),
        'max': np.array([1], dtype=np.float32),
        'mean': np.array([0.5], dtype=np.float32),
        'std': np.array([np.sqrt(1/12)], dtype=np.float32)
    }
    normalizer['image'] = SingleFieldLinearNormalizer.create_manual(
        scale=image_scale,
        offset=image_offset,
        input_stats_dict=image_stats
    )

    policy.set_normalizer(normalizer)
    print("Normalizer loaded from JSON (instant!)")

    # Sanity-print the per-field stats actually registered on the policy.
    # Use a fresh name (`field_stats`) — the original shadowed the JSON
    # `stats` dict loaded above.
    print("Normalizer stats:")
    for key in ['action', 'agent_pos']:
        if key in normalizer.params_dict:
            field_stats = normalizer.params_dict[key]['input_stats']
            print(f" {key}:")
            print(f" min: {field_stats['min'].detach().numpy()}")
            print(f" max: {field_stats['max'].detach().numpy()}")

    policy.eval().to(device)

    # Override the number of DDIM/DDPM steps at inference time
    # (speed/quality trade-off; presumably tuned for this deployment — confirm).
    policy.num_inference_steps = 16

    print(f"DP policy loaded successfully")
    print(f" - Device: {device}")
    print(f" - Policy type: {type(policy).__name__}")
    print(f" - Policy horizon: {policy.horizon}")
    print(f" - Policy n_action_steps: {policy.n_action_steps}")
    print(f" - Policy n_obs_steps: {policy.n_obs_steps}")
    print(f" - Policy action_dim: {policy.action_dim}")
    print(f" - Policy num_inference_steps: {policy.num_inference_steps}")

    return policy, cfg
|
|
|
|
class DPPolicyServer:
    """Diffusion Policy server for Franka robot using native interface.

    Speaks a length-prefixed pickle protocol over TCP: each message is a
    4-byte big-endian payload size followed by a pickled dict carrying a
    'command' key ('reset_policy', 'get_action', 'update_joints',
    'update_obs', 'close').

    NOTE(review): pickle deserializes arbitrary objects — this server must
    only be exposed to trusted clients on a trusted network.
    """

    def __init__(self, checkpoint_path, port=9966, device='cuda', task_name=None):
        """Load the policy and prepare server state (does not open the socket).

        Args:
            checkpoint_path: Path to the DP workspace checkpoint.
            port: TCP port to listen on.
            device: Preferred device string; falls back to CPU when CUDA is
                unavailable.
            task_name: Optional task name for a task-specific normalizer.
        """
        self.port = port
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        self.task_name = task_name

        self.policy, self.cfg = load_dp_policy(checkpoint_path, self.device, task_name=task_name)

        # Cached policy/config dimensions used when batching observations.
        self.n_obs_steps = self.cfg.n_obs_steps
        self.n_action_steps = self.policy.n_action_steps
        self.horizon = self.policy.horizon

        # shape_meta image shape is indexed as (C, H, W) below; incoming
        # frames are resized to (H, W) before inference.
        self.image_shape = self.cfg.task.shape_meta.obs.image.shape
        self.image_height = self.image_shape[1]
        self.image_width = self.image_shape[2]

        print(f" - Config n_obs_steps: {self.n_obs_steps}")
        print(f" - Policy n_action_steps: {self.n_action_steps}")
        print(f" - Policy horizon: {self.horizon}")
        print(f" - Expected image shape: {self.image_shape}")

        # Rolling window of the most recent n_obs_steps frames.
        self.obs_buffer = deque(maxlen=self.n_obs_steps)

        # Latest 8-dim joint state reported by the client; set on reset.
        self.current_joints = None

        print(f"DP server initialized")
        print(f" - Port: {port}")
        print(f" - Observation steps: {self.n_obs_steps}")
        print(f" - Action steps: {self.n_action_steps}")
        print(f" - Joint dimensions: 8 (7 joints + 1 gripper)")

    def reset(self, instruction, initial_joints):
        """Reset policy state with initial joint configuration.

        Args:
            instruction: Task instruction string (logged only; DP is not
                language-conditioned here).
            initial_joints: list or np.ndarray of 8 joint values.

        Returns:
            dict with 'status': 'success'.
        """
        print(f"\n[SERVER] Resetting policy...")
        print(f" - Instruction: {instruction}")
        print(f" - Initial joints: {initial_joints}")

        if isinstance(initial_joints, list):
            initial_joints = np.array(initial_joints, dtype=np.float32)

        self.current_joints = initial_joints
        self.obs_buffer.clear()

        # Some policies keep internal state across rollouts; clear it if so.
        if hasattr(self.policy, 'reset'):
            self.policy.reset()

        print(f"Policy reset complete")
        return {'status': 'success'}

    def get_action(self, image, instruction):
        """
        Get action from policy given current observation

        Args:
            image: numpy array (H, W, 3) RGB image (can be any size, will be resized)
            instruction: str (not used by DP but kept for API consistency)

        Returns:
            actions: (n_action_steps, 8) numpy array of future joint positions
        """
        print(f"\n[SERVER] Getting action...")
        print(f" - Image shape: {image.shape}")
        print(f" - Current joints: {self.current_joints}")
        print(f" - Obs buffer size: {len(self.obs_buffer)}")

        # Resize to the policy's expected (H, W) if needed.
        target_h, target_w = self.image_height, self.image_width
        if image.shape[:2] != (target_h, target_w):
            image_resized = cv2.resize(image, (target_w, target_h))
            print(f" - Resized from {image.shape} to {image_resized.shape}")
        else:
            image_resized = image
            print(f" - Image already correct size: {image.shape}")

        self.obs_buffer.append(image_resized)

        # Right after reset the buffer is short; pad by repeating the current
        # frame so the policy always sees n_obs_steps frames.
        while len(self.obs_buffer) < self.n_obs_steps:
            self.obs_buffer.append(image_resized)

        obs_dict = self._prepare_observation()

        with torch.no_grad():
            # Add the batch dimension and move tensors to the inference device.
            obs_dict_input = dict_apply(obs_dict,
                lambda x: x.unsqueeze(0).to(self.device) if isinstance(x, torch.Tensor) else x)

            result = self.policy.predict_action(obs_dict_input)
            action = result['action'][0].cpu().numpy()

        print(f" - Predicted actions shape: {action.shape}")
        print(f" - Action range: [{action.min():.4f}, {action.max():.4f}]")
        print(f" - Action[0]: {action[0]}")

        # The policy predicts the gripper in joint space; the client expects CGS.
        action_with_cgs = self._convert_gripper_to_cgs(action)

        print(f" - Action (with CGS) first: {action_with_cgs[0]}")
        print(f" - Action (with CGS) last: {action_with_cgs[-1]}")

        return action_with_cgs

    def update_joints(self, new_joints):
        """Update current joint state (called after executing actions)."""
        if isinstance(new_joints, list):
            new_joints = np.array(new_joints, dtype=np.float32)

        self.current_joints = new_joints
        print(f"[SERVER] Updated joints: {self.current_joints}")

    def update_obs(self, image):
        """Update observation buffer with intermediate frame (for temporal continuity)."""
        target_h, target_w = self.image_height, self.image_width
        if image.shape[:2] != (target_h, target_w):
            image_resized = cv2.resize(image, (target_w, target_h))
        else:
            image_resized = image
        self.obs_buffer.append(image_resized)
        print(f"[SERVER] Updated obs buffer, size: {len(self.obs_buffer)}")

    def _prepare_observation(self):
        """Prepare observation dictionary for DP input.

        Returns:
            dict with:
              'image': (n_obs_steps, 3, H, W) float tensor scaled to [0, 1]
              'agent_pos': (n_obs_steps, 8) float tensor — the current joint
                  state repeated for every step (no joint history is kept).
        """
        images = np.stack(list(self.obs_buffer), axis=0)

        # uint8 HWC -> float CHW in [0, 1].
        images_torch = torch.from_numpy(images).float() / 255.0
        images_torch = images_torch.permute(0, 3, 1, 2)

        agent_pos = torch.from_numpy(self.current_joints).float()
        agent_pos = agent_pos.unsqueeze(0).repeat(self.n_obs_steps, 1)

        obs_dict = {
            'image': images_torch,
            'agent_pos': agent_pos,
        }

        return obs_dict

    def _convert_gripper_to_cgs(self, actions):
        """
        Convert gripper dimension from joint space to CGS

        Args:
            actions: (n_action_steps, 8) numpy array
                Last dimension is gripper in joint space

        Returns:
            actions_cgs: (n_action_steps, 8) numpy array
                Last dimension converted to CGS
        """
        actions_cgs = actions.copy()

        # Only column 7 (the gripper) is converted; arm joints pass through.
        gripper_joints = actions[:, 7]
        gripper_cgs = joint_to_cgs(gripper_joints, method='linear_fit')
        actions_cgs[:, 7] = gripper_cgs

        print(f"[SERVER] Gripper conversion:")
        print(f" - Joint range: [{gripper_joints.min():.4f}, {gripper_joints.max():.4f}]")
        print(f" - CGS range: [{gripper_cgs.min():.4f}, {gripper_cgs.max():.4f}]")

        return actions_cgs

    def run(self):
        """Run the server socket loop.

        Accepts a single client and serves requests until the client sends
        'close', disconnects, or an error occurs; then shuts down.
        """
        server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        server_socket.bind(('0.0.0.0', self.port))
        server_socket.listen(1)

        print(f"\n{'='*70}")
        print(f"DP Server listening on port {self.port}")
        print(f"{'='*70}")
        print("Waiting for client connection...")

        client_socket, client_address = server_socket.accept()
        print(f"Client connected from {client_address}")

        try:
            while True:
                request = self._receive_request(client_socket)
                if request is None:
                    break

                command = request['command']
                print(f"\n[SERVER] Received command: {command}")

                if command == 'reset_policy':
                    response = self.reset(
                        request['instruction'],
                        request['initial_joints']
                    )

                elif command == 'get_action':
                    # Images are shipped as raw bytes + dtype + shape.
                    image_data = request['image']
                    image = np.frombuffer(
                        image_data['data'],
                        dtype=image_data['dtype']
                    ).reshape(image_data['shape'])

                    action = self.get_action(image, request.get('instruction', ''))

                    response = {
                        'action': action.tolist(),
                        'terminated': False
                    }

                elif command == 'update_joints':
                    self.update_joints(request['joints'])
                    response = {'status': 'success'}

                elif command == 'update_obs':
                    image_data = request['image']
                    image = np.frombuffer(
                        image_data['data'],
                        dtype=image_data['dtype']
                    ).reshape(image_data['shape'])

                    self.update_obs(image)
                    response = {'status': 'success'}

                elif command == 'close':
                    print("[SERVER] Received close command")
                    response = {'status': 'closing'}
                    self._send_response(client_socket, response)
                    break

                else:
                    response = {'error': f'Unknown command: {command}'}

                self._send_response(client_socket, response)

        except KeyboardInterrupt:
            print("\n\nServer interrupted by user")
        except Exception as e:
            print(f"\n\nError: {e}")
            import traceback
            traceback.print_exc()
        finally:
            client_socket.close()
            server_socket.close()
            print("Server closed")

    def _receive_request(self, client_socket):
        """Receive one length-prefixed pickled request; None on disconnect/error."""
        try:
            # First 4 bytes: big-endian payload size.
            size_data = self._recv_exactly(client_socket, 4)
            if len(size_data) < 4:
                # Clean disconnect or truncated header — do not decode a
                # partial size (the original only checked for empty bytes).
                return None

            data_size = int.from_bytes(size_data, byteorder='big')

            data = self._recv_exactly(client_socket, data_size)
            if len(data) != data_size:
                raise ConnectionError(f"Expected {data_size} bytes, got {len(data)}")

            request = pickle.loads(data)
            return request

        except Exception as e:
            print(f"Error receiving request: {e}")
            return None

    def _send_response(self, client_socket, response):
        """Send response to client (4-byte big-endian size + pickle payload)."""
        data = pickle.dumps(response)
        size = len(data)

        client_socket.sendall(size.to_bytes(4, byteorder='big'))
        client_socket.sendall(data)

    def _recv_exactly(self, sock, num_bytes):
        """Receive exactly num_bytes from socket.

        Returns fewer bytes only if the peer closes the connection early;
        callers must check the returned length.
        """
        data = b''
        while len(data) < num_bytes:
            packet = sock.recv(num_bytes - len(data))
            if not packet:
                break
            data += packet
        return data
|
|
|
|
def main():
    """Parse CLI arguments, print a startup banner, and run the DP server."""
    parser = argparse.ArgumentParser(description='Franka DP Policy Server (Native Interface)')
    parser.add_argument('--dp_checkpoint', type=str, required=True,
                        help='Path to DP checkpoint')
    parser.add_argument('--port', type=int, default=9966,
                        help='Server port')
    parser.add_argument('--device', type=str, default='cuda',
                        help='Device for inference')
    parser.add_argument('--task_name', type=str, default=None,
                        help='Task name for task-specific normalizer (e.g., tennis_bucket_upright)')
    args = parser.parse_args()

    banner = "=" * 70
    print(banner)
    print("Franka Diffusion Policy Server (Native Interface)")
    print(banner)
    print(f"Checkpoint: {args.dp_checkpoint}")
    print(f"Task: {args.task_name or 'all_tasks (global normalizer)'}")
    print(f"Port: {args.port}")
    print(f"Device: {args.device}")
    print(banner)

    # Build the server (loads the policy) and block in its socket loop.
    DPPolicyServer(
        checkpoint_path=args.dp_checkpoint,
        port=args.port,
        device=args.device,
        task_name=args.task_name,
    ).run()
|
|
|
|
# Script entry point: parse arguments and start the DP policy server.
if __name__ == '__main__':
    main()
|
|