Georg commited on
Commit
3a6146c
·
1 Parent(s): 55d7c05

Refactor homing control logic in mujoco_server.py for improved accuracy and responsiveness

Browse files

- Adjusted joint velocity calculations to prevent overshoot and ensure smoother homing behavior.
- Introduced tighter tolerance checks for stopping individual joints, enhancing precision during the homing process.
- Removed redundant debug print statements to streamline code and improve clarity.
- Updated velocity damping mechanism to gradually reduce speed as joints approach their target positions, ensuring a more controlled movement.

Files changed (1) hide show
  1. mujoco_server.py +40 -22
mujoco_server.py CHANGED
@@ -1175,14 +1175,12 @@ def handle_ws_message(ws, data):
1175
  with mujoco_lock, homing_lock:
1176
  # Update last home message timestamp
1177
  homing_state["last_home_time"] = time.time()
1178
- print(f"[Home] Processing home message, robot={current_robot}, env={env is not None}")
1179
  if env is not None and current_robot in ("ur5", "ur5_t_push"):
1180
  home_pose = getattr(env, "home_pose", None)
1181
  get_joints = getattr(env, "get_joint_positions", None)
1182
  start_multi_joint_jog = getattr(env, "start_multi_joint_jog", None)
1183
  stop_jog = getattr(env, "stop_jog", None)
1184
 
1185
- print(f"[Home] home_pose={home_pose is not None}, get_joints={callable(get_joints)}, multi_jog={callable(start_multi_joint_jog)}, stop={callable(stop_jog)}")
1186
  if home_pose is not None and callable(get_joints) and callable(start_multi_joint_jog) and callable(stop_jog):
1187
  if hasattr(home_pose, "tolist"):
1188
  target_positions = home_pose.tolist()
@@ -1191,7 +1189,6 @@ def handle_ws_message(ws, data):
1191
 
1192
  current_positions = get_joints()
1193
  tolerance = 0.01 # rad (~0.57 degrees)
1194
- print(f"[Home] target={target_positions[:3]}, current={current_positions[:3]}")
1195
 
1196
  # Calculate errors for all joints
1197
  errors = []
@@ -1212,6 +1209,7 @@ def handle_ws_message(ws, data):
1212
  homing_state["is_jogging"] = False
1213
  homing_state["active"] = False
1214
  homing_state.pop("previous_errors", None)
 
1215
  return
1216
 
1217
  # Initialize previous errors tracking if needed
@@ -1219,40 +1217,67 @@ def handle_ws_message(ws, data):
1219
  homing_state["previous_errors"] = [0.0] * len(errors)
1220
 
1221
  # Calculate proportional velocities for all joints
1222
- base_velocity = 0.3 # rad/s (increased from 0.1 for faster homing)
1223
  velocities = []
1224
 
 
 
 
1225
  for joint_idx, error in enumerate(errors):
1226
  abs_error = abs(error)
1227
 
1228
  # Check for overshoot (sign change)
1229
  prev_error = homing_state["previous_errors"][joint_idx]
1230
  if prev_error != 0 and ((prev_error > 0 and error < 0) or (prev_error < 0 and error > 0)):
1231
- # Overshoot detected - reduce velocity
1232
- base_velocity_joint = base_velocity * 0.5
1233
- else:
1234
- base_velocity_joint = base_velocity
1235
 
1236
- if abs_error < tolerance:
1237
- # Already at target
1238
  velocities.append(0.0)
1239
  else:
1240
  # Proportional control: scale velocity by distance relative to max error
1241
  # This ensures all joints reach home at approximately the same time
1242
  if max_abs_error > 0:
1243
- velocity = base_velocity_joint * (abs_error / max_abs_error)
1244
  else:
1245
  velocity = 0.0
1246
 
1247
- # Apply velocity in correct direction
1248
- if error > 0:
1249
- velocities.append(velocity)
 
 
 
 
 
 
 
 
 
1250
  else:
1251
- velocities.append(-velocity)
 
 
 
 
1252
 
1253
  # Store current errors for overshoot detection
1254
  homing_state["previous_errors"] = errors
1255
 
 
 
 
 
 
 
 
 
 
 
 
 
1256
  # Start multi-axis jogging with calculated velocities
1257
  current_velocities = homing_state.get("current_velocities")
1258
 
@@ -1521,9 +1546,6 @@ def _create_env(robot, scene):
1521
  @sock.route(f'{API_PREFIX}/ws')
1522
  def websocket_handler(ws):
1523
  """Handle WebSocket connections."""
1524
- with open('/tmp/ws_debug.log', 'a') as f:
1525
- f.write('WebSocket client connected\n')
1526
- f.flush()
1527
  print('WebSocket client connected', flush=True)
1528
 
1529
  # Register client (track all clients, even before they identify)
@@ -1541,10 +1563,6 @@ def websocket_handler(ws):
1541
  break
1542
  try:
1543
  data = json.loads(message)
1544
- with open('/tmp/ws_debug.log', 'a') as f:
1545
- f.write(f"Received message type: {data.get('type', 'unknown')}\n")
1546
- f.flush()
1547
- print(f"[WS] Received message type: {data.get('type', 'unknown')}")
1548
  handle_ws_message(ws, data)
1549
  except json.JSONDecodeError:
1550
  print(f"Invalid JSON received: {message}")
 
1175
  with mujoco_lock, homing_lock:
1176
  # Update last home message timestamp
1177
  homing_state["last_home_time"] = time.time()
 
1178
  if env is not None and current_robot in ("ur5", "ur5_t_push"):
1179
  home_pose = getattr(env, "home_pose", None)
1180
  get_joints = getattr(env, "get_joint_positions", None)
1181
  start_multi_joint_jog = getattr(env, "start_multi_joint_jog", None)
1182
  stop_jog = getattr(env, "stop_jog", None)
1183
 
 
1184
  if home_pose is not None and callable(get_joints) and callable(start_multi_joint_jog) and callable(stop_jog):
1185
  if hasattr(home_pose, "tolist"):
1186
  target_positions = home_pose.tolist()
 
1189
 
1190
  current_positions = get_joints()
1191
  tolerance = 0.01 # rad (~0.57 degrees)
 
1192
 
1193
  # Calculate errors for all joints
1194
  errors = []
 
1209
  homing_state["is_jogging"] = False
1210
  homing_state["active"] = False
1211
  homing_state.pop("previous_errors", None)
1212
+ print("[Homing] Completed - all joints at target")
1213
  return
1214
 
1215
  # Initialize previous errors tracking if needed
 
1217
  homing_state["previous_errors"] = [0.0] * len(errors)
1218
 
1219
  # Calculate proportional velocities for all joints
1220
+ base_velocity = 0.15 # rad/s - slower to prevent overshoot
1221
  velocities = []
1222
 
1223
+ # Tighter tolerance for stopping individual joints
1224
+ stop_tolerance = tolerance * 0.5 # Half the main tolerance
1225
+
1226
  for joint_idx, error in enumerate(errors):
1227
  abs_error = abs(error)
1228
 
1229
  # Check for overshoot (sign change)
1230
  prev_error = homing_state["previous_errors"][joint_idx]
1231
  if prev_error != 0 and ((prev_error > 0 and error < 0) or (prev_error < 0 and error > 0)):
1232
+ # Overshoot detected - stop this joint completely
1233
+ velocities.append(0.0)
1234
+ continue
 
1235
 
1236
+ if abs_error < stop_tolerance:
1237
+ # Close enough - stop this joint
1238
  velocities.append(0.0)
1239
  else:
1240
  # Proportional control: scale velocity by distance relative to max error
1241
  # This ensures all joints reach home at approximately the same time
1242
  if max_abs_error > 0:
1243
+ velocity = base_velocity * (abs_error / max_abs_error)
1244
  else:
1245
  velocity = 0.0
1246
 
1247
+ # Apply additional damping when very close to target
1248
+ # Gradually reduce velocity as we approach the target
1249
+ damping_zone = tolerance * 3.0 # Start damping at 3x tolerance
1250
+ if abs_error < damping_zone:
1251
+ # Quadratic damping: velocity reduces rapidly near target
1252
+ damping_factor = (abs_error / damping_zone) ** 2
1253
+ velocity *= damping_factor
1254
+
1255
+ # Minimum velocity threshold - below this, just stop
1256
+ min_velocity = 0.02 # rad/s
1257
+ if velocity < min_velocity:
1258
+ velocities.append(0.0)
1259
  else:
1260
+ # Apply velocity in correct direction
1261
+ if error > 0:
1262
+ velocities.append(velocity)
1263
+ else:
1264
+ velocities.append(-velocity)
1265
 
1266
  # Store current errors for overshoot detection
1267
  homing_state["previous_errors"] = errors
1268
 
1269
+ # Check if all velocities are zero (all joints at target or stopped due to overshoot)
1270
+ if all(abs(v) < 0.001 for v in velocities):
1271
+ # All joints stopped - end homing
1272
+ if homing_state.get("is_jogging", False):
1273
+ stop_jog()
1274
+ homing_state["is_jogging"] = False
1275
+ homing_state["active"] = False
1276
+ homing_state.pop("previous_errors", None)
1277
+ homing_state.pop("current_velocities", None)
1278
+ print("[Homing] All joints reached target or stopped")
1279
+ return
1280
+
1281
  # Start multi-axis jogging with calculated velocities
1282
  current_velocities = homing_state.get("current_velocities")
1283
 
 
1546
  @sock.route(f'{API_PREFIX}/ws')
1547
  def websocket_handler(ws):
1548
  """Handle WebSocket connections."""
 
 
 
1549
  print('WebSocket client connected', flush=True)
1550
 
1551
  # Register client (track all clients, even before they identify)
 
1563
  break
1564
  try:
1565
  data = json.loads(message)
 
 
 
 
1566
  handle_ws_message(ws, data)
1567
  except json.JSONDecodeError:
1568
  print(f"Invalid JSON received: {message}")