Fix fsd_model/planning.py for training (autograd + in-place ops)
Browse files- fsd_model/planning.py +6 -5
fsd_model/planning.py
CHANGED
|
@@ -257,11 +257,12 @@ class SafetyChecker(nn.Module):
|
|
| 257 |
planned_speeds = planned_waypoints[:, :, 3] # speed component
|
| 258 |
speed_violation = (planned_speeds > self.max_speed_ms).float().mean(dim=-1, keepdim=True)
|
| 259 |
|
| 260 |
-
# Clamp speeds to max
|
| 261 |
-
|
| 262 |
-
clamped_waypoints
|
| 263 |
-
planned_waypoints[:, :, 3],
|
| 264 |
-
|
|
|
|
| 265 |
|
| 266 |
return {
|
| 267 |
"collision_risk": collision_risk,
|
|
|
|
| 257 |
planned_speeds = planned_waypoints[:, :, 3] # speed component
|
| 258 |
speed_violation = (planned_speeds > self.max_speed_ms).float().mean(dim=-1, keepdim=True)
|
| 259 |
|
| 260 |
+
# Clamp speeds to max (no in-place ops for autograd)
|
| 261 |
+
clamped_speeds = torch.clamp(planned_waypoints[:, :, 3], 0.0, self.max_speed_ms)
|
| 262 |
+
clamped_waypoints = torch.cat([
|
| 263 |
+
planned_waypoints[:, :, :3],
|
| 264 |
+
clamped_speeds.unsqueeze(-1),
|
| 265 |
+
], dim=-1)
|
| 266 |
|
| 267 |
return {
|
| 268 |
"collision_risk": collision_risk,
|