Reality123b commited on
Commit
cd76a05
·
verified ·
1 Parent(s): 87321df

Add planning.py

Browse files
Files changed (1) hide show
  1. fsd_model/planning.py +335 -0
fsd_model/planning.py ADDED
@@ -0,0 +1,335 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Planning Module for FSD Model.
3
+ Handles:
4
+ 1. Route Planning (high-level waypoints from navigation)
5
+ 2. Behavior Planning (lane changes, turns, stops, yields)
6
+ 3. Trajectory Planning (smooth, collision-free path generation)
7
+ 4. Safety Verification (collision checking, emergency braking)
8
+
9
+ Architecture: Transformer-based planner that attends to perception features
10
+ and produces waypoint trajectories. Inspired by UniAD and VAD planners.
11
+ """
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ from typing import Dict, List, Optional, Tuple
17
+ import math
18
+
19
+
20
+ class PositionalEncoding2D(nn.Module):
21
+ """2D sinusoidal positional encoding for BEV features."""
22
+ def __init__(self, channels: int):
23
+ super().__init__()
24
+ self.channels = channels
25
+
26
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
27
+ B, C, H, W = x.shape
28
+ device = x.device
29
+
30
+ y_pos = torch.arange(H, device=device).float().unsqueeze(1).expand(H, W) / H
31
+ x_pos = torch.arange(W, device=device).float().unsqueeze(0).expand(H, W) / W
32
+
33
+ dim = torch.arange(0, self.channels // 4, device=device).float()
34
+ dim = 10000 ** (2 * dim / (self.channels // 2))
35
+
36
+ pe = torch.zeros(self.channels, H, W, device=device)
37
+ quarter = self.channels // 4
38
+ pe[0:quarter] = torch.sin(x_pos.unsqueeze(0) / dim.unsqueeze(1).unsqueeze(2))
39
+ pe[quarter:2*quarter] = torch.cos(x_pos.unsqueeze(0) / dim.unsqueeze(1).unsqueeze(2))
40
+ pe[2*quarter:3*quarter] = torch.sin(y_pos.unsqueeze(0) / dim.unsqueeze(1).unsqueeze(2))
41
+ pe[3*quarter:4*quarter] = torch.cos(y_pos.unsqueeze(0) / dim.unsqueeze(1).unsqueeze(2))
42
+
43
+ return x + pe.unsqueeze(0).expand(B, -1, -1, -1)
44
+
45
+
46
+ class BehaviorPredictor(nn.Module):
47
+ """
48
+ Predicts high-level driving behavior/command.
49
+ Commands: keep_lane, turn_left, turn_right, lane_change_left,
50
+ lane_change_right, stop, yield, park, reverse, emergency_stop
51
+ """
52
+ def __init__(self, in_channels: int = 256, num_behaviors: int = 10):
53
+ super().__init__()
54
+ self.num_behaviors = num_behaviors
55
+
56
+ self.encoder = nn.Sequential(
57
+ nn.AdaptiveAvgPool2d(8),
58
+ nn.Flatten(),
59
+ nn.Linear(in_channels * 64, 512),
60
+ nn.ReLU(),
61
+ nn.Dropout(0.3),
62
+ nn.Linear(512, 256),
63
+ nn.ReLU(),
64
+ nn.Dropout(0.2),
65
+ nn.Linear(256, num_behaviors),
66
+ )
67
+
68
+ def forward(self, bev: torch.Tensor) -> torch.Tensor:
69
+ """Returns: (B, num_behaviors) logits"""
70
+ return self.encoder(bev)
71
+
72
+
73
+ class TrajectoryTransformer(nn.Module):
74
+ """
75
+ Transformer-based trajectory planner.
76
+ Generates waypoints by attending to BEV features and navigation commands.
77
+ Uses learnable trajectory queries (similar to DETR object queries).
78
+ """
79
+ def __init__(
80
+ self,
81
+ bev_channels: int = 256,
82
+ d_model: int = 256,
83
+ nhead: int = 8,
84
+ num_decoder_layers: int = 6,
85
+ num_waypoints: int = 20, # planning horizon waypoints
86
+ dim_feedforward: int = 1024,
87
+ dropout: float = 0.1,
88
+ ):
89
+ super().__init__()
90
+ self.num_waypoints = num_waypoints
91
+ self.d_model = d_model
92
+
93
+ # BEV feature compression
94
+ self.bev_compress = nn.Sequential(
95
+ nn.Conv2d(bev_channels, d_model, 1),
96
+ nn.BatchNorm2d(d_model),
97
+ nn.ReLU(),
98
+ )
99
+ self.pos_encoding = PositionalEncoding2D(d_model)
100
+
101
+ # Learnable trajectory queries
102
+ self.trajectory_queries = nn.Parameter(
103
+ torch.randn(num_waypoints, d_model)
104
+ )
105
+
106
+ # Navigation command embedding (high-level route)
107
+ self.command_embed = nn.Embedding(10, d_model) # 10 possible commands
108
+
109
+ # Ego state embedding (speed, acceleration, steering)
110
+ self.ego_state_embed = nn.Sequential(
111
+ nn.Linear(6, d_model), # speed, accel, steer, yaw_rate, x, y
112
+ nn.ReLU(),
113
+ )
114
+
115
+ # Transformer decoder
116
+ decoder_layer = nn.TransformerDecoderLayer(
117
+ d_model=d_model,
118
+ nhead=nhead,
119
+ dim_feedforward=dim_feedforward,
120
+ dropout=dropout,
121
+ batch_first=True,
122
+ )
123
+ self.transformer_decoder = nn.TransformerDecoder(
124
+ decoder_layer, num_layers=num_decoder_layers
125
+ )
126
+
127
+ # Waypoint prediction heads
128
+ self.waypoint_head = nn.Sequential(
129
+ nn.Linear(d_model, 128),
130
+ nn.ReLU(),
131
+ nn.Linear(128, 4), # (x, y, heading, speed)
132
+ )
133
+
134
+ # Confidence / collision probability per waypoint
135
+ self.confidence_head = nn.Sequential(
136
+ nn.Linear(d_model, 64),
137
+ nn.ReLU(),
138
+ nn.Linear(64, 1),
139
+ nn.Sigmoid(),
140
+ )
141
+
142
+ def forward(
143
+ self,
144
+ bev_features: torch.Tensor,
145
+ ego_state: torch.Tensor,
146
+ nav_command: Optional[torch.Tensor] = None,
147
+ ) -> Dict[str, torch.Tensor]:
148
+ """
149
+ Args:
150
+ bev_features: (B, C, H, W) from perception
151
+ ego_state: (B, 6) current ego state [speed, accel, steer, yaw_rate, x, y]
152
+ nav_command: (B,) integer navigation command
153
+ Returns:
154
+ waypoints: (B, num_waypoints, 4) predicted trajectory
155
+ confidence: (B, num_waypoints, 1) per-waypoint confidence
156
+ """
157
+ B = bev_features.shape[0]
158
+ device = bev_features.device
159
+
160
+ # Compress and add positional encoding to BEV
161
+ bev = self.bev_compress(bev_features)
162
+ bev = self.pos_encoding(bev)
163
+
164
+ # Flatten BEV to sequence: (B, H*W, d_model)
165
+ bev_seq = bev.flatten(2).permute(0, 2, 1)
166
+
167
+ # Build trajectory queries
168
+ queries = self.trajectory_queries.unsqueeze(0).expand(B, -1, -1)
169
+
170
+ # Add ego state information to queries
171
+ ego_feat = self.ego_state_embed(ego_state).unsqueeze(1)
172
+ queries = queries + ego_feat
173
+
174
+ # Add navigation command if provided
175
+ if nav_command is not None:
176
+ cmd_feat = self.command_embed(nav_command).unsqueeze(1)
177
+ queries = queries + cmd_feat
178
+
179
+ # Transformer decoding
180
+ decoded = self.transformer_decoder(queries, bev_seq)
181
+
182
+ # Predict waypoints and confidence
183
+ waypoints = self.waypoint_head(decoded) # (B, T, 4)
184
+ confidence = self.confidence_head(decoded) # (B, T, 1)
185
+
186
+ return {
187
+ "waypoints": waypoints,
188
+ "confidence": confidence,
189
+ }
190
+
191
+
192
+ class SafetyChecker(nn.Module):
193
+ """
194
+ Verifies planned trajectories against safety constraints.
195
+ Checks for:
196
+ - Collision with detected objects
197
+ - Lane boundary violations
198
+ - Speed limit violations
199
+ - Minimum following distance
200
+ - Emergency stop conditions
201
+ """
202
+ def __init__(
203
+ self,
204
+ bev_channels: int = 256,
205
+ max_speed_ms: float = 8.94, # 20 mph
206
+ min_following_distance: float = 4.0, # meters
207
+ emergency_decel: float = 8.0, # m/s^2
208
+ ):
209
+ super().__init__()
210
+ self.max_speed_ms = max_speed_ms
211
+ self.min_following_distance = min_following_distance
212
+ self.emergency_decel = emergency_decel
213
+
214
+ # Collision risk estimator
215
+ self.collision_net = nn.Sequential(
216
+ nn.AdaptiveAvgPool2d(8),
217
+ nn.Flatten(),
218
+ nn.Linear(bev_channels * 64, 256),
219
+ nn.ReLU(),
220
+ nn.Linear(256, 64),
221
+ nn.ReLU(),
222
+ nn.Linear(64, 1),
223
+ nn.Sigmoid(),
224
+ )
225
+
226
+ # Emergency brake detector
227
+ self.emergency_detector = nn.Sequential(
228
+ nn.AdaptiveAvgPool2d(4),
229
+ nn.Flatten(),
230
+ nn.Linear(bev_channels * 16, 128),
231
+ nn.ReLU(),
232
+ nn.Linear(128, 2), # [no_emergency, emergency]
233
+ )
234
+
235
+ def forward(
236
+ self,
237
+ bev: torch.Tensor,
238
+ planned_waypoints: torch.Tensor,
239
+ ego_state: torch.Tensor,
240
+ ) -> Dict[str, torch.Tensor]:
241
+ """
242
+ Args:
243
+ bev: (B, C, H, W) BEV features with occupancy info
244
+ planned_waypoints: (B, T, 4) planned trajectory
245
+ ego_state: (B, 6)
246
+ Returns:
247
+ Dict with safety scores and emergency signals
248
+ """
249
+ # Collision risk
250
+ collision_risk = self.collision_net(bev)
251
+
252
+ # Emergency brake
253
+ emergency_logits = self.emergency_detector(bev)
254
+ emergency_prob = F.softmax(emergency_logits, dim=-1)[:, 1:]
255
+
256
+ # Speed constraint check
257
+ planned_speeds = planned_waypoints[:, :, 3] # speed component
258
+ speed_violation = (planned_speeds > self.max_speed_ms).float().mean(dim=-1, keepdim=True)
259
+
260
+ # Clamp speeds to max
261
+ clamped_waypoints = planned_waypoints.clone()
262
+ clamped_waypoints[:, :, 3] = torch.clamp(
263
+ planned_waypoints[:, :, 3], 0.0, self.max_speed_ms
264
+ )
265
+
266
+ return {
267
+ "collision_risk": collision_risk,
268
+ "emergency_brake": emergency_prob,
269
+ "speed_violation": speed_violation,
270
+ "safe_waypoints": clamped_waypoints,
271
+ }
272
+
273
+
274
+ class PlanningModule(nn.Module):
275
+ """
276
+ Complete planning module.
277
+ Pipeline: BEV → Behavior Prediction → Trajectory Generation → Safety Check
278
+ """
279
+ def __init__(
280
+ self,
281
+ bev_channels: int = 256,
282
+ d_model: int = 256,
283
+ num_waypoints: int = 20,
284
+ max_speed_ms: float = 8.94,
285
+ num_behaviors: int = 10,
286
+ ):
287
+ super().__init__()
288
+
289
+ self.behavior_predictor = BehaviorPredictor(bev_channels, num_behaviors)
290
+ self.trajectory_planner = TrajectoryTransformer(
291
+ bev_channels=bev_channels,
292
+ d_model=d_model,
293
+ num_waypoints=num_waypoints,
294
+ )
295
+ self.safety_checker = SafetyChecker(
296
+ bev_channels=bev_channels,
297
+ max_speed_ms=max_speed_ms,
298
+ )
299
+
300
+ def forward(
301
+ self,
302
+ bev_features: torch.Tensor,
303
+ ego_state: torch.Tensor,
304
+ nav_command: Optional[torch.Tensor] = None,
305
+ ) -> Dict[str, torch.Tensor]:
306
+ """
307
+ Args:
308
+ bev_features: (B, C, H, W)
309
+ ego_state: (B, 6) [speed, accel, steer, yaw_rate, x, y]
310
+ nav_command: (B,) high-level navigation command
311
+ Returns:
312
+ Complete planning output including safe trajectory
313
+ """
314
+ # Predict behavior
315
+ behavior_logits = self.behavior_predictor(bev_features)
316
+
317
+ # Generate trajectory
318
+ traj_output = self.trajectory_planner(
319
+ bev_features, ego_state, nav_command
320
+ )
321
+
322
+ # Safety verification
323
+ safety = self.safety_checker(
324
+ bev_features, traj_output["waypoints"], ego_state
325
+ )
326
+
327
+ return {
328
+ "behavior_logits": behavior_logits,
329
+ "raw_waypoints": traj_output["waypoints"],
330
+ "waypoint_confidence": traj_output["confidence"],
331
+ "safe_waypoints": safety["safe_waypoints"],
332
+ "collision_risk": safety["collision_risk"],
333
+ "emergency_brake": safety["emergency_brake"],
334
+ "speed_violation": safety["speed_violation"],
335
+ }