File size: 14,946 Bytes
6487de0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9078db7
 
6487de0
9078db7
 
 
 
 
 
 
 
6487de0
9078db7
 
 
 
 
 
 
 
 
 
 
 
6487de0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
902cc59
 
 
6487de0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9078db7
 
 
 
 
 
6487de0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
902cc59
6487de0
902cc59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6487de0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
902cc59
6487de0
902cc59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6487de0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
"""Quadruped-PyMPC inspired Controller for Spot robot.

This controller uses components from Quadruped-PyMPC for:
- Periodic gait generation
- Swing trajectory planning

But outputs joint position targets (not torques) to work with SpotEnv's
position-controlled actuators.

Requires: quadruped_pympc, jax, gym_quadruped
"""
import numpy as np
import os
import importlib.util

# Import BaseController
# The base class is loaded straight from the sibling file "base.py" by path
# (registered under the module name "spot_base") instead of through a normal
# package import.
# NOTE(review): presumably done so this file works without the controllers
# directory being an importable package - confirm before replacing it with
# a plain "from .base import BaseController".
_controllers_dir = os.path.dirname(os.path.abspath(__file__))
_base_path = os.path.join(_controllers_dir, "base.py")
_spec = importlib.util.spec_from_file_location("spot_base", _base_path)
_base_module = importlib.util.module_from_spec(_spec)
# exec_module actually runs base.py; a missing or broken base.py surfaces here.
_spec.loader.exec_module(_base_module)
BaseController = _base_module.BaseController

# Try to import Quadruped-PyMPC components.
# PYMPC_AVAILABLE stays False unless one of the two import strategies below
# succeeds; the controller class checks it before building a gait generator.
PYMPC_AVAILABLE = False

# First try local standalone implementation
try:
    import sys

    # Get the parent directory of controllers (the spot directory) and make
    # it importable so that "quadruped_deps" can be found there.
    spot_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    if spot_dir not in sys.path:
        sys.path.insert(0, spot_dir)

    from quadruped_deps import LegsAttr, GaitType, PeriodicGaitGenerator, mpc_params as pympc_config
    PYMPC_AVAILABLE = True
    print("Quadruped-PyMPC components loaded from local standalone implementation")
except ImportError as local_err:
    # Fall back to external packages
    try:
        from gym_quadruped.utils.quadruped_utils import LegsAttr
        from quadruped_pympc.helpers.periodic_gait_generator import PeriodicGaitGenerator
        from quadruped_pympc.helpers.quadruped_utils import GaitType
        import quadruped_pympc.config as pympc_config
        PYMPC_AVAILABLE = True
        print("Quadruped-PyMPC loaded from external packages")
    except ImportError as e:
        print(f"Quadruped-PyMPC not available (local or external): {e}")
        # Report the first failure too; otherwise the reason the local
        # implementation could not be loaded is silently lost.
        print(f"  local standalone import error: {local_err}")


class QuadrupedPyMPCController(BaseController):
    """
    Controller using Quadruped-PyMPC's gait generator for timing.

    A ``PeriodicGaitGenerator`` supplies a per-leg phase signal. Each control
    step that phase is turned into joint-space swing or stance offsets around
    a nominal standing pose, with simple roll/pitch/height feedback added
    during stance. If the gait generator could not be imported or built, a
    time-based trot (``_simple_gait``) is used instead.

    Outputs position targets compatible with SpotEnv's position actuators:
    12 values ordered leg by leg (FL, FR, RL, RR), each leg as
    (hip roll, hip pitch, knee).
    """

    # Nominal standing angles for a single leg (hip roll, hip pitch, knee).
    # Kept as plain Python floats so per-leg arithmetic is unchanged.
    STAND_HIP_ROLL = 0.0
    STAND_HIP_PITCH = 1.04
    STAND_KNEE = -1.8

    # Standing pose for all four legs, built from the per-leg constants.
    STANDING_POSE = np.tile(
        np.array([STAND_HIP_ROLL, STAND_HIP_PITCH, STAND_KNEE], dtype=np.float32),
        4,
    )

    # Joint limits (from spot.xml), as (lower, upper).
    HIP_ROLL_LIMITS = (-0.785, 0.785)
    HIP_PITCH_LIMITS = (-0.9, 2.29)
    KNEE_LIMITS = (-2.79, -0.25)

    # Below this filtered command magnitude the robot just stands.
    WALK_SPEED_THRESHOLD = 0.05
    # Lateral / yaw commands smaller than this are treated as zero.
    AXIS_DEADBAND = 0.01

    # Leg indices matching PyMPC convention
    LEGS_ORDER = ('FL', 'FR', 'RL', 'RR')
    LEG_IDX = {'FL': 0, 'FR': 1, 'RL': 2, 'RR': 3}

    def __init__(self, num_joints: int = 12, model=None, data=None):
        """Set up gains and gait parameters and, if possible, the gait generator.

        Args:
            num_joints: Forwarded to ``BaseController.__init__``.
            model: Stored on ``self.model``; not otherwise used by this class.
            data: Stored on ``self.data``; overwritten by every
                ``compute_action`` call.
        """
        super().__init__(num_joints)

        self.cmd = np.array([0.0, 0.0, 0.0], dtype=np.float32)
        self.model = model
        self.data = data
        self.pympc_initialized = False
        # Always defined, so no hasattr() check is needed later; stays None
        # when PyMPC is unavailable or fails to initialise.
        self.gait_generator = None

        # Robot parameters (Spot)
        self.mass = 32.86
        self.hip_height = 0.46

        # Simulation parameters: time step fed to the gait generator.
        self.sim_dt = 0.002

        # Gait parameters
        self.step_freq = 2.5  # Hz
        self.duty_factor = 0.55  # fraction of the cycle spent in stance

        # Motion parameters (joint-space amplitudes)
        self.swing_height = 0.35  # radians - knee lift during swing
        self.stride_length = 0.5  # radians - hip pitch sweep per unit vx
        self.strafe_amplitude = 0.6  # radians - lateral hip roll amplitude
        self.strafe_knee_extra = 0.2  # extra knee lift during strafe
        self.turn_amplitude = 0.3  # radians - turning hip roll amplitude

        # Feedback gains for balance
        self.kp_roll = 0.5
        self.kp_pitch = 0.4
        self.kp_height = 2.5
        self.kd_roll = 0.1
        self.kd_pitch = 0.1
        self.target_height = 0.46

        # Command smoothing: weight kept on the previous filtered command
        # (the new command gets 1 - cmd_filter).
        self.cmd_filter = 0.6
        self.filtered_cmd = np.array([0.0, 0.0, 0.0], dtype=np.float32)

        # Initialize PyMPC components if available
        if PYMPC_AVAILABLE:
            self._init_pympc()

    def _init_pympc(self):
        """Initialize Quadruped-PyMPC gait generator.

        Best effort: any failure is printed with a traceback and leaves
        ``pympc_initialized`` False so the controller uses the fallback gait.
        """
        try:
            # The config is either the external module (which exposes an
            # ``mpc_params`` mapping) or, for the local standalone
            # implementation, the ``mpc_params`` mapping itself.
            if hasattr(pympc_config, 'mpc_params'):
                horizon = pympc_config.mpc_params.get('horizon', 12)
            else:
                horizon = pympc_config.get('horizon', 12)

            # Initialize periodic gait generator for trot
            self.gait_generator = PeriodicGaitGenerator(
                duty_factor=self.duty_factor,
                step_freq=self.step_freq,
                gait_type=GaitType.TROT,
                horizon=horizon,
            )

            self.pympc_initialized = True
            print("PyMPC gait generator initialized")

        except Exception as e:
            print(f"Failed to initialize PyMPC components: {e}")
            import traceback
            traceback.print_exc()
            self.pympc_initialized = False

    def set_command(self, vx: float = 0.0, vy: float = 0.0, vyaw: float = 0.0):
        """Set velocity command.

        Values are clipped to [-1, 1] for ``vx`` and ``vyaw`` and to
        [-0.5, 0.5] for ``vy`` before being stored.
        """
        self.cmd[0] = np.clip(vx, -1.0, 1.0)
        self.cmd[1] = np.clip(vy, -0.5, 0.5)
        self.cmd[2] = np.clip(vyaw, -1.0, 1.0)

    def get_command(self):
        """Get current (unfiltered) velocity command as a copy."""
        return self.cmd.copy()

    def _get_leg_phase(self, leg_name: str) -> tuple:
        """Get swing/stance info for a leg from PyMPC gait generator.

        Args:
            leg_name: One of the keys of ``LEG_IDX``.

        Returns:
            ``(in_swing, phase_in_mode)`` where ``phase_in_mode`` is the
            leg's progress through its current mode, rescaled so each mode
            starts at 0.
        """
        phase = self.gait_generator.phase_signal[self.LEG_IDX[leg_name]]

        # NOTE(review): swing is taken to be the first (1 - duty_factor) of
        # the cycle; confirm this matches the gait generator's own
        # contact convention.
        swing_duration = 1.0 - self.duty_factor
        in_swing = phase < swing_duration

        if in_swing:
            # Swing spans [0, swing_duration)
            phase_in_mode = phase / swing_duration
        else:
            # Stance spans [swing_duration, 1.0)
            phase_in_mode = (phase - swing_duration) / self.duty_factor

        return in_swing, phase_in_mode

    @staticmethod
    def _quat_roll(quat):
        """Return the roll angle of a (w, x, y, z) quaternion."""
        qw, qx, qy, qz = quat
        return np.arctan2(2 * (qw * qx + qy * qz), 1 - 2 * (qx**2 + qy**2))

    @staticmethod
    def _quat_pitch(quat):
        """Return the pitch angle of a (w, x, y, z) quaternion."""
        qw, qx, qy, qz = quat
        # Clip guards arcsin against values just outside [-1, 1].
        sin_pitch = np.clip(2 * (qw * qy - qz * qx), -1.0, 1.0)
        return np.arcsin(sin_pitch)

    @staticmethod
    def _is_leading_leg(is_left, vy):
        """Return True if the leg is on the side the robot strafes toward."""
        return (is_left and vy > 0) or (not is_left and vy < 0)

    def _clip_leg_joints(self, hx, hy, kn):
        """Clamp one leg's (hip roll, hip pitch, knee) to the joint limits."""
        return (
            np.clip(hx, *self.HIP_ROLL_LIMITS),
            np.clip(hy, *self.HIP_PITCH_LIMITS),
            np.clip(kn, *self.KNEE_LIMITS),
        )

    def _swing_targets(self, progress, is_left, is_front, vx, vy, vyaw):
        """Return unclipped (hx, hy, kn) for a leg in swing.

        ``progress`` is the leg's progress through swing, starting at 0.
        """
        hx = self.STAND_HIP_ROLL

        # sweep goes +1 -> -1 over the swing; lift peaks mid-swing.
        sweep = 1 - 2 * progress
        lift = np.sin(np.pi * progress)

        # Hip pitch: swing leg forward/backward
        hy_offset = self.stride_length * vx * sweep
        # Knee: lift in the middle of swing
        kn_offset = -self.swing_height * lift

        # TURNING: hip roll moves outward during swing
        if abs(vyaw) > self.AXIS_DEADBAND:
            sign = 1.0 if is_left else -1.0
            # Differential: front/back legs turn opposite
            turn_sign = 1.0 if is_front else -1.0
            hx += sign * turn_sign * self.turn_amplitude * vyaw * lift
            # Also adjust hip pitch for turning
            hy_offset += turn_sign * 0.15 * vyaw * sweep

        # STRAFE: lateral motion during swing
        if abs(vy) > self.AXIS_DEADBAND:
            if self._is_leading_leg(is_left, vy):
                # Leading leg: full-amplitude hip roll to reach
                hx += self.strafe_amplitude * abs(vy) * lift
            else:
                # Trailing leg: half-amplitude hip roll the other way
                hx -= self.strafe_amplitude * 0.5 * abs(vy) * lift

            # Extra knee lift for clearance during strafe
            kn_offset -= self.strafe_knee_extra * abs(vy) * lift
            # Slight hip pitch adjustment for diagonal reach
            hy_offset += 0.1 * abs(vy) * sweep

        return hx, self.STAND_HIP_PITCH + hy_offset, self.STAND_KNEE + kn_offset

    def _stance_targets(self, progress, is_left, is_front, vx, vy, vyaw,
                        roll_correction, pitch_correction, height_correction):
        """Return unclipped (hx, hy, kn) for a leg in stance.

        ``progress`` is the leg's progress through stance, starting at 0.
        The three ``*_correction`` terms are the balance feedback computed
        once per step in ``compute_action``.
        """
        hx = self.STAND_HIP_ROLL

        # Hip pitch: push backward for forward motion
        hy_offset = self.stride_length * vx * (2 * progress - 1)
        # Height feedback acts through the knee
        kn_offset = height_correction * 0.3

        # Roll correction through hip roll (mirrored left/right)
        if is_left:
            hx += roll_correction
        else:
            hx -= roll_correction

        # Pitch correction through hip pitch (mirrored front/rear)
        if is_front:
            hy_offset -= pitch_correction * 0.5
        else:
            hy_offset += pitch_correction * 0.5

        # TURNING in stance: rotate body
        if abs(vyaw) > self.AXIS_DEADBAND:
            turn_sign = 1.0 if is_front else -1.0
            sign = 1.0 if is_left else -1.0
            hx += sign * turn_sign * self.turn_amplitude * vyaw * np.cos(np.pi * progress)

        # STRAFE in stance: push body laterally
        if abs(vy) > self.AXIS_DEADBAND:
            push = self.strafe_amplitude * 0.7 * abs(vy)
            if self._is_leading_leg(is_left, vy):
                # Pull - hip roll decreases to bring body over
                hx -= push * (1 - progress)
            else:
                # Push - hip roll increases to push body away
                hx += push * progress
                # Also slight knee bend on trailing side for push
                kn_offset -= 0.1 * abs(vy) * progress

        return hx, self.STAND_HIP_PITCH + hy_offset, self.STAND_KNEE + kn_offset

    def compute_action(self, obs: np.ndarray, data) -> np.ndarray:
        """Compute joint targets using PyMPC gait timing.

        Args:
            obs: Observation vector. ``obs[2]`` is read as base height,
                ``obs[3:7]`` as the base quaternion (w, x, y, z) and
                ``obs[10:13]`` as the base angular velocity.
            data: Stored on ``self.data``; not otherwise used here.

        Returns:
            A float32 array of 12 joint position targets, clipped to the
            joint limits. The standing pose is returned while the filtered
            command magnitude is below ``WALK_SPEED_THRESHOLD``.
        """
        self.data = data

        # Smooth command transition
        self.filtered_cmd = (self.cmd_filter * self.filtered_cmd +
                             (1 - self.cmd_filter) * self.cmd)
        vx, vy, vyaw = self.filtered_cmd

        # Check if should walk
        speed = np.sqrt(vx**2 + vy**2 + vyaw**2)
        if speed < self.WALK_SPEED_THRESHOLD:
            return self.STANDING_POSE.copy()

        # If PyMPC not available, use simple fallback
        if not self.pympc_initialized:
            return self._simple_gait(obs)

        # Advance the gait generator by one simulation step
        self.gait_generator.run(self.sim_dt, self.step_freq)

        # Extract state from observation
        base_quat = obs[3:7]  # w, x, y, z
        base_ang_vel = obs[10:13]
        height = obs[2]

        roll = self._quat_roll(base_quat)
        pitch = self._quat_pitch(base_quat)
        roll_rate = base_ang_vel[0]
        pitch_rate = base_ang_vel[1]
        height_error = self.target_height - height

        # PD feedback on roll/pitch, P feedback on height
        roll_correction = -self.kp_roll * roll - self.kd_roll * roll_rate
        pitch_correction = -self.kp_pitch * pitch - self.kd_pitch * pitch_rate
        height_correction = self.kp_height * height_error

        target = np.zeros(12, dtype=np.float32)

        for leg_name in self.LEGS_ORDER:
            in_swing, phase = self._get_leg_phase(leg_name)
            is_front = leg_name in ('FL', 'FR')
            is_left = leg_name in ('FL', 'RL')

            if in_swing:
                joints = self._swing_targets(phase, is_left, is_front, vx, vy, vyaw)
            else:
                joints = self._stance_targets(
                    phase, is_left, is_front, vx, vy, vyaw,
                    roll_correction, pitch_correction, height_correction,
                )

            hx, hy, kn = self._clip_leg_joints(*joints)

            base_idx = self.LEG_IDX[leg_name] * 3
            target[base_idx + 0] = hx
            target[base_idx + 1] = hy
            target[base_idx + 2] = kn

        return target

    def _simple_gait(self, obs):
        """Fallback simple trot gait when PyMPC not available.

        Only the forward component of the filtered command is used; lateral
        and yaw commands have no effect in this fallback. Gait phase is
        derived from ``self.time``.
        NOTE(review): ``self.time`` is not set in this class - presumably
        maintained by BaseController; confirm.
        """
        vx = self.filtered_cmd[0]

        phase = (self.time * self.step_freq) % 1.0
        phase_offsets = np.array([0.0, 0.5, 0.5, 0.0])  # Trot pattern
        swing_duration = 1.0 - self.duty_factor

        # Roll is the same for every leg, so compute it once.
        roll = self._quat_roll(obs[3:7])

        target = np.zeros(12, dtype=np.float32)

        for leg_idx in range(4):
            leg_phase = (phase + phase_offsets[leg_idx]) % 1.0
            is_left = leg_idx in (0, 2)  # FL, RL

            hx = self.STAND_HIP_ROLL
            hy = self.STAND_HIP_PITCH
            kn = self.STAND_KNEE

            if leg_phase < swing_duration:
                p = leg_phase / swing_duration
                hy += self.stride_length * vx * (1 - 2 * p)
                kn += -self.swing_height * np.sin(np.pi * p)
            else:
                p = (leg_phase - swing_duration) / self.duty_factor
                hy += self.stride_length * vx * (2 * p - 1)

            # Balance feedback (mirrored left/right)
            if is_left:
                hx -= 0.3 * roll
            else:
                hx += 0.3 * roll

            hx, hy, kn = self._clip_leg_joints(hx, hy, kn)

            base_idx = leg_idx * 3
            target[base_idx + 0] = hx
            target[base_idx + 1] = hy
            target[base_idx + 2] = kn

        return target

    def reset(self):
        """Reset the base controller, zero both commands and restart the gait."""
        super().reset()
        self.cmd = np.array([0.0, 0.0, 0.0], dtype=np.float32)
        self.filtered_cmd = np.array([0.0, 0.0, 0.0], dtype=np.float32)

        if self.pympc_initialized and self.gait_generator is not None:
            self.gait_generator.reset()

    @property
    def name(self) -> str:
        """Display name including the active gait source and the raw command."""
        status = "PyMPC" if self.pympc_initialized else "Fallback"
        return f"PyMPC Gait ({status}) (cmd: {self.cmd[0]:.1f}, {self.cmd[1]:.1f}, {self.cmd[2]:.1f})"