chrisjcc commited on
Commit
f4dbc3b
·
verified ·
1 Parent(s): ec90f47

Using LLMPhy paper's definition of stability to implement classify orientation stability

Browse files
Files changed (1) hide show
  1. tray_sim.py +51 -0
tray_sim.py CHANGED
@@ -16,6 +16,57 @@ SIM_STEPS = 200
16
  IMPACT_STEP = 60 # is a good starting point, just after the pusher activates.
17
 
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  def classify_stability(data, model):
20
  """
21
  Classify object stability based on position data.
 
16
  IMPACT_STEP = 60 # is a good starting point, just after the pusher activates.
17
 
18
 
19
+ def quat_to_rotmat(q):
20
+ """Convert quaternion to rotation matrix."""
21
+ w, x, y, z = q
22
+ return np.array([
23
+ [1 - 2*y**2 - 2*z**2, 2*x*y - 2*z*w, 2*x*z + 2*y*w],
24
+ [2*x*y + 2*z*w, 1 - 2*x**2 - 2*z**2, 2*y*z - 2*x*w],
25
+ [2*x*z - 2*y*w, 2*y*z + 2*x*w, 1 - 2*x**2 - 2*y**2]
26
+ ])
27
+
28
+ def classify_orientation_stability(data, model, alpha_deg=45):
29
+ """
30
+ Check object stability using orientation.
31
+ Returns a list of booleans, True if the object is stable.
32
+ """
33
+ stable_objects = []
34
+ cos_alpha = np.cos(np.deg2rad(alpha_deg))
35
+ g = np.array([0, 0, -1])
36
+
37
+ for i in range(N_OBJECTS):
38
+ q = data.qpos[i*7 + 3 : i*7 + 7]
39
+ rotmat = quat_to_rotmat(q)
40
+ obj_z_world = rotmat[:, 2] # local z-axis in world coords
41
+ cos_theta = -np.dot(g, obj_z_world) # because gravity is downward
42
+ stable = cos_theta > cos_alpha
43
+ stable_objects.append(stable)
44
+
45
+ return stable_objects
46
+
47
+ def classify_orientation_stability_verbose(data, model, alpha_deg=45):
48
+ results = []
49
+ cos_alpha = np.cos(np.deg2rad(alpha_deg))
50
+ g = np.array([0, 0, -1])
51
+
52
+ for i in range(N_OBJECTS):
53
+ q = data.qpos[i*7 + 3 : i*7 + 7]
54
+ rotmat = quat_to_rotmat(q)
55
+ obj_z_world = rotmat[:, 2]
56
+ cos_theta = -np.dot(g, obj_z_world)
57
+
58
+ if cos_theta > cos_alpha:
59
+ results.append({"stable": True, "reason": "upright (tilt within limit)"})
60
+ else:
61
+ angle_deg = np.rad2deg(np.arccos(cos_theta))
62
+ results.append({
63
+ "stable": False,
64
+ "reason": f"tilt angle {angle_deg:.2f}° exceeds {alpha_deg}°"
65
+ })
66
+
67
+ return results
68
+
69
+
70
  def classify_stability(data, model):
71
  """
72
  Classify object stability based on position data.