Georg commited on
Commit ·
3a6146c
1
Parent(s): 55d7c05
Refactor homing control logic in mujoco_server.py for improved accuracy and responsiveness
Browse files- Adjusted joint velocity calculations to prevent overshoot and ensure smoother homing behavior.
- Introduced tighter tolerance checks for stopping individual joints, enhancing precision during the homing process.
- Removed redundant debug print statements to streamline code and improve clarity.
- Updated velocity damping mechanism to gradually reduce speed as joints approach their target positions, ensuring a more controlled movement.
- mujoco_server.py +40 -22
mujoco_server.py
CHANGED
|
@@ -1175,14 +1175,12 @@ def handle_ws_message(ws, data):
|
|
| 1175 |
with mujoco_lock, homing_lock:
|
| 1176 |
# Update last home message timestamp
|
| 1177 |
homing_state["last_home_time"] = time.time()
|
| 1178 |
-
print(f"[Home] Processing home message, robot={current_robot}, env={env is not None}")
|
| 1179 |
if env is not None and current_robot in ("ur5", "ur5_t_push"):
|
| 1180 |
home_pose = getattr(env, "home_pose", None)
|
| 1181 |
get_joints = getattr(env, "get_joint_positions", None)
|
| 1182 |
start_multi_joint_jog = getattr(env, "start_multi_joint_jog", None)
|
| 1183 |
stop_jog = getattr(env, "stop_jog", None)
|
| 1184 |
|
| 1185 |
-
print(f"[Home] home_pose={home_pose is not None}, get_joints={callable(get_joints)}, multi_jog={callable(start_multi_joint_jog)}, stop={callable(stop_jog)}")
|
| 1186 |
if home_pose is not None and callable(get_joints) and callable(start_multi_joint_jog) and callable(stop_jog):
|
| 1187 |
if hasattr(home_pose, "tolist"):
|
| 1188 |
target_positions = home_pose.tolist()
|
|
@@ -1191,7 +1189,6 @@ def handle_ws_message(ws, data):
|
|
| 1191 |
|
| 1192 |
current_positions = get_joints()
|
| 1193 |
tolerance = 0.01 # rad (~0.57 degrees)
|
| 1194 |
-
print(f"[Home] target={target_positions[:3]}, current={current_positions[:3]}")
|
| 1195 |
|
| 1196 |
# Calculate errors for all joints
|
| 1197 |
errors = []
|
|
@@ -1212,6 +1209,7 @@ def handle_ws_message(ws, data):
|
|
| 1212 |
homing_state["is_jogging"] = False
|
| 1213 |
homing_state["active"] = False
|
| 1214 |
homing_state.pop("previous_errors", None)
|
|
|
|
| 1215 |
return
|
| 1216 |
|
| 1217 |
# Initialize previous errors tracking if needed
|
|
@@ -1219,40 +1217,67 @@ def handle_ws_message(ws, data):
|
|
| 1219 |
homing_state["previous_errors"] = [0.0] * len(errors)
|
| 1220 |
|
| 1221 |
# Calculate proportional velocities for all joints
|
| 1222 |
-
base_velocity = 0.
|
| 1223 |
velocities = []
|
| 1224 |
|
|
|
|
|
|
|
|
|
|
| 1225 |
for joint_idx, error in enumerate(errors):
|
| 1226 |
abs_error = abs(error)
|
| 1227 |
|
| 1228 |
# Check for overshoot (sign change)
|
| 1229 |
prev_error = homing_state["previous_errors"][joint_idx]
|
| 1230 |
if prev_error != 0 and ((prev_error > 0 and error < 0) or (prev_error < 0 and error > 0)):
|
| 1231 |
-
# Overshoot detected -
|
| 1232 |
-
|
| 1233 |
-
|
| 1234 |
-
base_velocity_joint = base_velocity
|
| 1235 |
|
| 1236 |
-
if abs_error <
|
| 1237 |
-
#
|
| 1238 |
velocities.append(0.0)
|
| 1239 |
else:
|
| 1240 |
# Proportional control: scale velocity by distance relative to max error
|
| 1241 |
# This ensures all joints reach home at approximately the same time
|
| 1242 |
if max_abs_error > 0:
|
| 1243 |
-
velocity =
|
| 1244 |
else:
|
| 1245 |
velocity = 0.0
|
| 1246 |
|
| 1247 |
-
# Apply
|
| 1248 |
-
|
| 1249 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1250 |
else:
|
| 1251 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1252 |
|
| 1253 |
# Store current errors for overshoot detection
|
| 1254 |
homing_state["previous_errors"] = errors
|
| 1255 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1256 |
# Start multi-axis jogging with calculated velocities
|
| 1257 |
current_velocities = homing_state.get("current_velocities")
|
| 1258 |
|
|
@@ -1521,9 +1546,6 @@ def _create_env(robot, scene):
|
|
| 1521 |
@sock.route(f'{API_PREFIX}/ws')
|
| 1522 |
def websocket_handler(ws):
|
| 1523 |
"""Handle WebSocket connections."""
|
| 1524 |
-
with open('/tmp/ws_debug.log', 'a') as f:
|
| 1525 |
-
f.write('WebSocket client connected\n')
|
| 1526 |
-
f.flush()
|
| 1527 |
print('WebSocket client connected', flush=True)
|
| 1528 |
|
| 1529 |
# Register client (track all clients, even before they identify)
|
|
@@ -1541,10 +1563,6 @@ def websocket_handler(ws):
|
|
| 1541 |
break
|
| 1542 |
try:
|
| 1543 |
data = json.loads(message)
|
| 1544 |
-
with open('/tmp/ws_debug.log', 'a') as f:
|
| 1545 |
-
f.write(f"Received message type: {data.get('type', 'unknown')}\n")
|
| 1546 |
-
f.flush()
|
| 1547 |
-
print(f"[WS] Received message type: {data.get('type', 'unknown')}")
|
| 1548 |
handle_ws_message(ws, data)
|
| 1549 |
except json.JSONDecodeError:
|
| 1550 |
print(f"Invalid JSON received: {message}")
|
|
|
|
| 1175 |
with mujoco_lock, homing_lock:
|
| 1176 |
# Update last home message timestamp
|
| 1177 |
homing_state["last_home_time"] = time.time()
|
|
|
|
| 1178 |
if env is not None and current_robot in ("ur5", "ur5_t_push"):
|
| 1179 |
home_pose = getattr(env, "home_pose", None)
|
| 1180 |
get_joints = getattr(env, "get_joint_positions", None)
|
| 1181 |
start_multi_joint_jog = getattr(env, "start_multi_joint_jog", None)
|
| 1182 |
stop_jog = getattr(env, "stop_jog", None)
|
| 1183 |
|
|
|
|
| 1184 |
if home_pose is not None and callable(get_joints) and callable(start_multi_joint_jog) and callable(stop_jog):
|
| 1185 |
if hasattr(home_pose, "tolist"):
|
| 1186 |
target_positions = home_pose.tolist()
|
|
|
|
| 1189 |
|
| 1190 |
current_positions = get_joints()
|
| 1191 |
tolerance = 0.01 # rad (~0.57 degrees)
|
|
|
|
| 1192 |
|
| 1193 |
# Calculate errors for all joints
|
| 1194 |
errors = []
|
|
|
|
| 1209 |
homing_state["is_jogging"] = False
|
| 1210 |
homing_state["active"] = False
|
| 1211 |
homing_state.pop("previous_errors", None)
|
| 1212 |
+
print("[Homing] Completed - all joints at target")
|
| 1213 |
return
|
| 1214 |
|
| 1215 |
# Initialize previous errors tracking if needed
|
|
|
|
| 1217 |
homing_state["previous_errors"] = [0.0] * len(errors)
|
| 1218 |
|
| 1219 |
# Calculate proportional velocities for all joints
|
| 1220 |
+
base_velocity = 0.15 # rad/s - slower to prevent overshoot
|
| 1221 |
velocities = []
|
| 1222 |
|
| 1223 |
+
# Tighter tolerance for stopping individual joints
|
| 1224 |
+
stop_tolerance = tolerance * 0.5 # Half the main tolerance
|
| 1225 |
+
|
| 1226 |
for joint_idx, error in enumerate(errors):
|
| 1227 |
abs_error = abs(error)
|
| 1228 |
|
| 1229 |
# Check for overshoot (sign change)
|
| 1230 |
prev_error = homing_state["previous_errors"][joint_idx]
|
| 1231 |
if prev_error != 0 and ((prev_error > 0 and error < 0) or (prev_error < 0 and error > 0)):
|
| 1232 |
+
# Overshoot detected - stop this joint completely
|
| 1233 |
+
velocities.append(0.0)
|
| 1234 |
+
continue
|
|
|
|
| 1235 |
|
| 1236 |
+
if abs_error < stop_tolerance:
|
| 1237 |
+
# Close enough - stop this joint
|
| 1238 |
velocities.append(0.0)
|
| 1239 |
else:
|
| 1240 |
# Proportional control: scale velocity by distance relative to max error
|
| 1241 |
# This ensures all joints reach home at approximately the same time
|
| 1242 |
if max_abs_error > 0:
|
| 1243 |
+
velocity = base_velocity * (abs_error / max_abs_error)
|
| 1244 |
else:
|
| 1245 |
velocity = 0.0
|
| 1246 |
|
| 1247 |
+
# Apply additional damping when very close to target
|
| 1248 |
+
# Gradually reduce velocity as we approach the target
|
| 1249 |
+
damping_zone = tolerance * 3.0 # Start damping at 3x tolerance
|
| 1250 |
+
if abs_error < damping_zone:
|
| 1251 |
+
# Quadratic damping: velocity reduces rapidly near target
|
| 1252 |
+
damping_factor = (abs_error / damping_zone) ** 2
|
| 1253 |
+
velocity *= damping_factor
|
| 1254 |
+
|
| 1255 |
+
# Minimum velocity threshold - below this, just stop
|
| 1256 |
+
min_velocity = 0.02 # rad/s
|
| 1257 |
+
if velocity < min_velocity:
|
| 1258 |
+
velocities.append(0.0)
|
| 1259 |
else:
|
| 1260 |
+
# Apply velocity in correct direction
|
| 1261 |
+
if error > 0:
|
| 1262 |
+
velocities.append(velocity)
|
| 1263 |
+
else:
|
| 1264 |
+
velocities.append(-velocity)
|
| 1265 |
|
| 1266 |
# Store current errors for overshoot detection
|
| 1267 |
homing_state["previous_errors"] = errors
|
| 1268 |
|
| 1269 |
+
# Check if all velocities are zero (all joints at target or stopped due to overshoot)
|
| 1270 |
+
if all(abs(v) < 0.001 for v in velocities):
|
| 1271 |
+
# All joints stopped - end homing
|
| 1272 |
+
if homing_state.get("is_jogging", False):
|
| 1273 |
+
stop_jog()
|
| 1274 |
+
homing_state["is_jogging"] = False
|
| 1275 |
+
homing_state["active"] = False
|
| 1276 |
+
homing_state.pop("previous_errors", None)
|
| 1277 |
+
homing_state.pop("current_velocities", None)
|
| 1278 |
+
print("[Homing] All joints reached target or stopped")
|
| 1279 |
+
return
|
| 1280 |
+
|
| 1281 |
# Start multi-axis jogging with calculated velocities
|
| 1282 |
current_velocities = homing_state.get("current_velocities")
|
| 1283 |
|
|
|
|
| 1546 |
@sock.route(f'{API_PREFIX}/ws')
|
| 1547 |
def websocket_handler(ws):
|
| 1548 |
"""Handle WebSocket connections."""
|
|
|
|
|
|
|
|
|
|
| 1549 |
print('WebSocket client connected', flush=True)
|
| 1550 |
|
| 1551 |
# Register client (track all clients, even before they identify)
|
|
|
|
| 1563 |
break
|
| 1564 |
try:
|
| 1565 |
data = json.loads(message)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1566 |
handle_ws_message(ws, data)
|
| 1567 |
except json.JSONDecodeError:
|
| 1568 |
print(f"Invalid JSON received: {message}")
|