Fix fsd_model/cot_reasoning.py for training (autograd + in-place ops)
Browse files
- fsd_model/cot_reasoning.py +13 -9
fsd_model/cot_reasoning.py
CHANGED
|
@@ -491,21 +491,25 @@ class SafetyDecisionGate(nn.Module):
|
|
| 491 |
planner_speeds = planner_waypoints[:, :, 3]
|
| 492 |
cot_speeds = cot_waypoints[:, :, 3]
|
| 493 |
safe_speeds = torch.min(planner_speeds, F.relu(cot_speeds))
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
|
|
|
|
|
|
|
| 499 |
|
| 500 |
# Blend: output = (1-alpha)*planner + alpha*cot
|
| 501 |
alpha_expanded = alpha.unsqueeze(-1) # (B, 1, 1)
|
| 502 |
gated_waypoints = (1 - alpha_expanded) * planner_waypoints + alpha_expanded * cot_waypoints
|
| 503 |
|
| 504 |
# Ensure gated speeds never exceed planner speeds (monotonic safety)
|
| 505 |
-
gated_waypoints[:, :, 3]
|
| 506 |
-
|
| 507 |
-
|
| 508 |
-
|
|
|
|
|
|
|
| 509 |
|
| 510 |
# Post-gate safety score
|
| 511 |
safety = self.safety_score(justification_embedding)
|
|
|
|
| 491 |
planner_speeds = planner_waypoints[:, :, 3]
|
| 492 |
cot_speeds = cot_waypoints[:, :, 3]
|
| 493 |
safe_speeds = torch.min(planner_speeds, F.relu(cot_speeds))
|
| 494 |
+
safe_speeds = torch.clamp(safe_speeds, 0.0, self.max_speed_ms)
|
| 495 |
+
|
| 496 |
+
# Build cot_waypoints without in-place ops
|
| 497 |
+
cot_waypoints = torch.cat([
|
| 498 |
+
cot_waypoints[:, :, :3],
|
| 499 |
+
safe_speeds.unsqueeze(-1),
|
| 500 |
+
], dim=-1)
|
| 501 |
|
| 502 |
# Blend: output = (1-alpha)*planner + alpha*cot
|
| 503 |
alpha_expanded = alpha.unsqueeze(-1) # (B, 1, 1)
|
| 504 |
gated_waypoints = (1 - alpha_expanded) * planner_waypoints + alpha_expanded * cot_waypoints
|
| 505 |
|
| 506 |
# Ensure gated speeds never exceed planner speeds (monotonic safety)
|
| 507 |
+
gated_speeds = torch.min(gated_waypoints[:, :, 3], planner_waypoints[:, :, 3])
|
| 508 |
+
gated_speeds = torch.clamp(gated_speeds, 0.0, self.max_speed_ms)
|
| 509 |
+
gated_waypoints = torch.cat([
|
| 510 |
+
gated_waypoints[:, :, :3],
|
| 511 |
+
gated_speeds.unsqueeze(-1),
|
| 512 |
+
], dim=-1)
|
| 513 |
|
| 514 |
# Post-gate safety score
|
| 515 |
safety = self.safety_score(justification_embedding)
|