Reality123b committed on
Commit
885b5ec
·
verified ·
1 Parent(s): cd76a05

Add control.py

Browse files
Files changed (1) hide show
  1. fsd_model/control.py +396 -0
fsd_model/control.py ADDED
@@ -0,0 +1,396 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Control Module for FSD Model.
3
+ Converts planned trajectory waypoints into actuator commands:
4
+ - Steering angle
5
+ - Throttle (acceleration)
6
+ - Brake
7
+ - Gear (forward/reverse/park)
8
+
9
+ Uses a combination of:
10
+ 1. PID controllers for smooth tracking
11
+ 2. Neural network for adaptive control
12
+ 3. Stanley controller for lateral control
13
+ 4. Bicycle model for vehicle dynamics
14
+ """
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+ from typing import Dict, Optional, Tuple
20
+ import math
21
+
22
+
23
class BicycleModel(nn.Module):
    """
    Kinematic bicycle model for vehicle dynamics simulation.

    Used for both prediction and control. The reference point is placed
    midway between the axles (see ``_REAR_AXLE_RATIO``), which is what the
    ``0.5 * tan(steer)`` slip-angle expression encodes.

    State:   [x, y, heading, speed]
    Control: [steering_angle, acceleration]
    """

    # Distance from the rear axle to the reference point, as a fraction of
    # the wheelbase. Used for both the slip angle and the yaw rate so the two
    # stay consistent with each other.
    _REAR_AXLE_RATIO = 0.5

    def __init__(self, wheelbase: float = 2.7, dt: float = 0.1):
        """
        Args:
            wheelbase: distance between front and rear axle.
            dt: integration time step used by ``forward``.
        """
        super().__init__()
        self.wheelbase = wheelbase
        self.dt = dt

    def forward(
        self, state: torch.Tensor, control: torch.Tensor
    ) -> torch.Tensor:
        """
        Advance the state by one explicit-Euler step of length ``self.dt``.

        Args:
            state: (..., 4) [x, y, heading, speed]
            control: (..., 2) [steering_angle, acceleration]
        Returns:
            next_state: (..., 4) with the same leading dimensions as ``state``
        """
        # Ellipsis indexing keeps the usual (B, 4) case working and also
        # accepts extra leading dimensions such as (B, T, 4).
        x, y = state[..., 0], state[..., 1]
        heading, speed = state[..., 2], state[..., 3]
        steer, accel = control[..., 0], control[..., 1]

        # Slip angle at the reference point.
        beta = torch.atan(self._REAR_AXLE_RATIO * torch.tan(steer))
        course = heading + beta

        x_new = x + speed * torch.cos(course) * self.dt
        y_new = y + speed * torch.sin(course) * self.dt

        # Yaw rate is v / l_r * sin(beta), where l_r is the rear-axle
        # distance, not the full wheelbase. Dividing by the full wheelbase
        # would halve the turning rate implied by the slip angle above.
        rear_length = self._REAR_AXLE_RATIO * self.wheelbase
        heading_new = heading + (speed / rear_length) * torch.sin(beta) * self.dt

        # The model has no reverse gear: speed is clamped to be non-negative.
        speed_new = torch.clamp(speed + accel * self.dt, min=0.0)

        return torch.stack([x_new, y_new, heading_new, speed_new], dim=-1)
60
+
61
+
62
class PIDController(nn.Module):
    """
    Learnable PID controller with neural network gain scheduling.

    Gains (Kp, Ki, Kd) are predicted from the current ego state for two
    channels (lateral and longitudinal). The accumulated integral and the
    previous error are kept between calls in module buffers; call ``reset``
    at the start of a new episode.
    """

    # Symmetric anti-windup bound applied to the accumulated integral.
    _INTEGRAL_LIMIT = 10.0

    def __init__(self, state_dim: int = 6, hidden_dim: int = 64):
        """
        Args:
            state_dim: size of the ego-state vector fed to the gain network.
            hidden_dim: width of the gain network's hidden layers.
        """
        super().__init__()

        # Gain predictor network
        self.gain_net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 6),  # Kp, Ki, Kd for lateral + longitudinal
            nn.Softplus(),  # Ensure positive gains
        )

        # Controller memory (not parameters). Both buffers are re-allocated
        # in ``forward`` whenever the batch size changes.
        self.register_buffer('integral_error', torch.zeros(1, 2))
        self.register_buffer('prev_error', torch.zeros(1, 2))

    def forward(
        self,
        error: torch.Tensor,
        ego_state: torch.Tensor,
        dt: float = 0.1,
    ) -> torch.Tensor:
        """
        Args:
            error: (B, 2) [lateral_error, longitudinal_error]
            ego_state: (B, state_dim) current vehicle state
            dt: time step
        Returns:
            control: (B, 2) [steering_correction, accel_correction]
        """
        batch_size = error.shape[0]

        # Predict adaptive gains: columns 0-1 are Kp, 2-3 are Ki, 4-5 are Kd.
        gains = self.gain_net(ego_state)
        kp, ki, kd = gains[:, :2], gains[:, 2:4], gains[:, 4:6]

        # A batch-size change starts the controller memory from zero again.
        if self.integral_error.shape[0] != batch_size:
            self.integral_error = torch.zeros(
                batch_size, 2, device=error.device, dtype=error.dtype
            )
        if self.prev_error.shape[0] != batch_size:
            self.prev_error = torch.zeros(
                batch_size, 2, device=error.device, dtype=error.dtype
            )

        proportional = kp * error

        # Integral with anti-windup clamp. The current step keeps its
        # gradient path, but the value carried to the next call is detached:
        # otherwise every call would extend one autograd graph across the
        # whole episode and a later backward() would walk into graph buffers
        # that an earlier backward() already freed.
        accumulated = torch.clamp(
            self.integral_error + error * dt,
            -self._INTEGRAL_LIMIT,
            self._INTEGRAL_LIMIT,
        )
        integral = ki * accumulated
        self.integral_error = accumulated.detach()

        # Derivative on the raw error difference.
        derivative = kd * (error - self.prev_error) / dt
        self.prev_error = error.detach()

        return proportional + integral + derivative

    def reset(self):
        """Reset integral and derivative buffers."""
        self.integral_error.zero_()
        self.prev_error.zero_()
131
+
132
+
133
class StanleyController(nn.Module):
    """
    Stanley lateral controller enhanced with learned parameters.

    Computes a steering angle from:
    1. Heading error
    2. Cross-track error (scaled by a speed-dependent term)
    """

    def __init__(
        self,
        k_gain: float = 0.5,
        k_soft: float = 1.0,
        max_steer_deg: float = 35.0,
    ):
        """
        Args:
            k_gain: initial cross-track gain (learnable).
            k_soft: initial softening constant added to the speed (learnable).
            max_steer_deg: steering limit in degrees applied to the output;
                the default keeps the previously hard-coded 35 degree limit.
        """
        super().__init__()
        # Learnable gains
        self.k_gain = nn.Parameter(torch.tensor(k_gain))
        self.k_soft = nn.Parameter(torch.tensor(k_soft))
        # Fixed limit, converted once instead of on every forward pass.
        self.max_steer = math.radians(max_steer_deg)

    def forward(
        self,
        heading_error: torch.Tensor,
        cross_track_error: torch.Tensor,
        speed: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            heading_error: (B,) heading difference to path
            cross_track_error: (B,) lateral distance to path
            speed: (B,) current speed
        Returns:
            steering: (B,) desired steering angle (radians), clamped to
                +/- ``self.max_steer``
        """
        # Stanley formula: atan(k * e / (v + k_soft)). k_soft keeps the
        # denominator away from zero when the vehicle is stationary.
        cross_track_steer = torch.atan2(
            self.k_gain * cross_track_error,
            speed + self.k_soft,
        )
        steering = heading_error + cross_track_steer

        return torch.clamp(steering, -self.max_steer, self.max_steer)
172
+
173
+
174
class NeuralController(nn.Module):
    """
    End-to-end neural network controller.

    Encodes BEV features, planned waypoints and the ego state separately,
    concatenates the three embeddings and regresses steering, throttle and
    brake commands from them. Serves as a refinement on top of the classical
    controllers.
    """

    def __init__(
        self,
        bev_channels: int = 256,
        waypoint_dim: int = 4,
        num_waypoints: int = 20,
        ego_dim: int = 6,
        hidden_dim: int = 256,
    ):
        """
        Args:
            bev_channels: channel count of the incoming BEV feature map.
            waypoint_dim: number of values per waypoint.
            num_waypoints: number of waypoints expected by the encoder.
            ego_dim: size of the ego-state vector.
            hidden_dim: width of the BEV / waypoint embeddings; the ego
                embedding uses half of it.
        """
        super().__init__()

        pool_size = 4
        pooled_cells = pool_size * pool_size
        half_hidden = hidden_dim // 2
        fused_dim = hidden_dim * 2 + half_hidden

        # BEV branch: pool the map to a fixed 4x4 grid, flatten, project.
        self.bev_encoder = nn.Sequential(
            nn.AdaptiveAvgPool2d(pool_size),
            nn.Flatten(),
            nn.Linear(bev_channels * pooled_cells, hidden_dim),
            nn.ReLU(),
        )

        # Waypoint branch: the whole trajectory is flattened to one vector.
        self.waypoint_encoder = nn.Sequential(
            nn.Flatten(),
            nn.Linear(num_waypoints * waypoint_dim, hidden_dim),
            nn.ReLU(),
        )

        # Ego-state branch.
        self.ego_encoder = nn.Sequential(
            nn.Linear(ego_dim, half_hidden),
            nn.ReLU(),
        )

        # Head over the concatenated embeddings.
        self.control_head = nn.Sequential(
            nn.Linear(fused_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 3),  # steering, throttle, brake
        )

    def forward(
        self,
        bev_features: torch.Tensor,
        waypoints: torch.Tensor,
        ego_state: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """
        Args:
            bev_features: BEV feature map fed to the pooling encoder.
            waypoints: planned waypoints; flattened per sample before the
                linear layer.
            ego_state: ego-state vector per sample.
        Returns:
            Dict with steering (-1 to 1), throttle (0 to 1), brake (0 to 1)
        """
        embeddings = [
            self.bev_encoder(bev_features),
            self.waypoint_encoder(waypoints),
            self.ego_encoder(ego_state),
        ]
        logits = self.control_head(torch.cat(embeddings, dim=-1))

        # Squash each raw output into its actuator range.
        return {
            "steering": torch.tanh(logits[:, 0]),     # [-1, 1]
            "throttle": torch.sigmoid(logits[:, 1]),  # [0, 1]
            "brake": torch.sigmoid(logits[:, 2]),     # [0, 1]
        }
248
+
249
+
250
class ControlModule(nn.Module):
    """
    Complete control module that combines:
    1. Neural controller (BEV-aware, end-to-end)
    2. Stanley controller (geometric lateral control)
    3. PID controller (error-based correction)
    4. Bicycle model (physics-based prediction)
    5. Safety limits enforcement
    """

    def __init__(
        self,
        bev_channels: int = 256,
        num_waypoints: int = 20,
        wheelbase: float = 2.7,
        max_speed_ms: float = 8.94,
        max_steering_deg: float = 35.0,
        max_accel: float = 3.0,
        max_decel: float = 8.0,
        dt: float = 0.1,
    ):
        """
        Args:
            bev_channels: channel count of the BEV feature map.
            num_waypoints: number of planned waypoints per sample.
            wheelbase: wheelbase passed to the bicycle model.
            max_speed_ms: stored on the module; not used by ``forward``.
            max_steering_deg: steering limit in degrees for the final command.
            max_accel: acceleration produced by full throttle.
            max_decel: deceleration produced by full brake.
            dt: control / prediction time step.
        """
        super().__init__()
        self.max_speed_ms = max_speed_ms
        self.max_steering = math.radians(max_steering_deg)
        self.max_accel = max_accel
        self.max_decel = max_decel
        self.dt = dt

        # Sub-controllers
        self.neural_controller = NeuralController(
            bev_channels=bev_channels,
            num_waypoints=num_waypoints,
        )
        self.stanley_controller = StanleyController()
        self.pid_controller = PIDController()
        self.bicycle_model = BicycleModel(wheelbase, dt)

        # Controller fusion weights (learned)
        self.fusion_weights = nn.Sequential(
            nn.Linear(6, 32),  # ego state -> weights
            nn.ReLU(),
            nn.Linear(32, 3),  # weights for [neural, stanley, pid]
            nn.Softmax(dim=-1),
        )

    def forward(
        self,
        bev_features: torch.Tensor,
        planned_waypoints: torch.Tensor,
        ego_state: torch.Tensor,
        emergency_brake: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Args:
            bev_features: (B, C, H, W) BEV features
            planned_waypoints: (B, T, 4) [x, y, heading, speed]
            ego_state: (B, 6) [speed, accel, steer, yaw_rate, x, y]
            emergency_brake: (B, 1) emergency brake probability
        Returns:
            Dict with final actuator commands:
            ``steering_rad``, ``steering_deg``, ``throttle``, ``brake``,
            ``acceleration_cmd``, ``controller_weights`` and
            ``predicted_next_state``.
        """
        ego_speed = ego_state[:, 0]

        # 1. Neural controller output
        neural_out = self.neural_controller(bev_features, planned_waypoints, ego_state)

        # 2. Stanley controller - errors relative to the first waypoint
        next_wp = planned_waypoints[:, 0, :]
        wp_heading = next_wp[:, 2]
        # The ego state carries no heading, so yaw_rate is used as a proxy.
        heading_error = wp_heading - ego_state[:, 3]

        # Signed cross-track error: the offset to the waypoint projected on
        # the path normal (-sin h, cos h). It is positive when the path lies
        # to the left of the vehicle, which is the direction a positive
        # steering angle turns towards in the bicycle model. An unsigned
        # Euclidean distance would make the lateral terms steer the same way
        # regardless of side, and its sqrt has an infinite gradient at zero.
        dx = next_wp[:, 0] - ego_state[:, 4]
        dy = next_wp[:, 1] - ego_state[:, 5]
        cross_track_error = dy * torch.cos(wp_heading) - dx * torch.sin(wp_heading)

        stanley_steer = self.stanley_controller(
            heading_error, cross_track_error, ego_speed
        )

        # 3. PID controller on [lateral error, speed error]
        speed_err = next_wp[:, 3] - ego_speed
        pid_error = torch.stack([cross_track_error, speed_err], dim=-1)
        pid_out = self.pid_controller(pid_error, ego_state, self.dt)

        # 4. Fuse controllers based on driving state
        weights = self.fusion_weights(ego_state)  # (B, 3)

        # Neural steering + Stanley steering + PID steering
        neural_steer = neural_out["steering"] * self.max_steering
        pid_steer = torch.clamp(pid_out[:, 0], -self.max_steering, self.max_steering)
        final_steering = (
            weights[:, 0] * neural_steer +
            weights[:, 1] * stanley_steer +
            weights[:, 2] * pid_steer
        )

        # Throttle/brake: neural output plus a PID speed correction that is
        # scaled by the PID fusion weight. Positive PID output adds throttle,
        # negative PID output adds brake.
        pid_accel = pid_out[:, 1]
        final_throttle = neural_out["throttle"] + torch.clamp(pid_accel, 0, 1) * weights[:, 2]
        final_brake = neural_out["brake"] + torch.clamp(-pid_accel, 0, 1) * weights[:, 2]

        # 5. Safety overrides: above 0.5 probability, cut throttle and force
        # full brake for that sample.
        if emergency_brake is not None:
            emergency_mask = (emergency_brake.squeeze(-1) > 0.5).float()
            final_throttle = final_throttle * (1 - emergency_mask)
            final_brake = torch.max(final_brake, emergency_mask)

        # Clamp all outputs
        final_steering = torch.clamp(final_steering, -self.max_steering, self.max_steering)
        final_throttle = torch.clamp(final_throttle, 0.0, 1.0)
        final_brake = torch.clamp(final_brake, 0.0, 1.0)

        # If braking > throttle, zero out throttle. Note that the brake is
        # left untouched when throttle dominates.
        brake_dominant = (final_brake > final_throttle).float()
        final_throttle = final_throttle * (1 - brake_dominant)

        # Convert to physical units
        accel_cmd = final_throttle * self.max_accel - final_brake * self.max_decel
        steer_deg = torch.rad2deg(final_steering)

        # Predict next state using bicycle model
        current_state = torch.stack([
            ego_state[:, 4],  # x
            ego_state[:, 5],  # y
            ego_state[:, 3],  # heading (yaw_rate as proxy)
            ego_speed,        # speed
        ], dim=-1)

        control_input = torch.stack([final_steering, accel_cmd], dim=-1)
        predicted_next_state = self.bicycle_model(current_state, control_input)

        return {
            "steering_rad": final_steering,
            "steering_deg": steer_deg,
            "throttle": final_throttle,
            "brake": final_brake,
            "acceleration_cmd": accel_cmd,
            "controller_weights": weights,
            "predicted_next_state": predicted_next_state,
        }

    def reset(self):
        """Reset controller states (call at start of new episode)."""
        self.pid_controller.reset()