Reality123b commited on
Commit
efdd5e3
·
verified ·
1 Parent(s): 2a2385f

Fix fsd_model/cot_reasoning.py for training (autograd + in-place ops)

Browse files
Files changed (1) hide show
  1. fsd_model/cot_reasoning.py +13 -9
fsd_model/cot_reasoning.py CHANGED
@@ -491,21 +491,25 @@ class SafetyDecisionGate(nn.Module):
491
  planner_speeds = planner_waypoints[:, :, 3]
492
  cot_speeds = cot_waypoints[:, :, 3]
493
  safe_speeds = torch.min(planner_speeds, F.relu(cot_speeds))
494
- cot_waypoints = cot_waypoints.clone()
495
- cot_waypoints[:, :, 3] = safe_speeds
496
-
497
- # Clamp all speeds
498
- cot_waypoints[:, :, 3] = torch.clamp(cot_waypoints[:, :, 3], 0.0, self.max_speed_ms)
 
 
499
 
500
  # Blend: output = (1-alpha)*planner + alpha*cot
501
  alpha_expanded = alpha.unsqueeze(-1) # (B, 1, 1)
502
  gated_waypoints = (1 - alpha_expanded) * planner_waypoints + alpha_expanded * cot_waypoints
503
 
504
  # Ensure gated speeds never exceed planner speeds (monotonic safety)
505
- gated_waypoints[:, :, 3] = torch.min(
506
- gated_waypoints[:, :, 3], planner_waypoints[:, :, 3]
507
- )
508
- gated_waypoints[:, :, 3] = torch.clamp(gated_waypoints[:, :, 3], 0.0, self.max_speed_ms)
 
 
509
 
510
  # Post-gate safety score
511
  safety = self.safety_score(justification_embedding)
 
491
  planner_speeds = planner_waypoints[:, :, 3]
492
  cot_speeds = cot_waypoints[:, :, 3]
493
  safe_speeds = torch.min(planner_speeds, F.relu(cot_speeds))
494
+ safe_speeds = torch.clamp(safe_speeds, 0.0, self.max_speed_ms)
495
+
496
+ # Build cot_waypoints without in-place ops
497
+ cot_waypoints = torch.cat([
498
+ cot_waypoints[:, :, :3],
499
+ safe_speeds.unsqueeze(-1),
500
+ ], dim=-1)
501
 
502
  # Blend: output = (1-alpha)*planner + alpha*cot
503
  alpha_expanded = alpha.unsqueeze(-1) # (B, 1, 1)
504
  gated_waypoints = (1 - alpha_expanded) * planner_waypoints + alpha_expanded * cot_waypoints
505
 
506
  # Ensure gated speeds never exceed planner speeds (monotonic safety)
507
+ gated_speeds = torch.min(gated_waypoints[:, :, 3], planner_waypoints[:, :, 3])
508
+ gated_speeds = torch.clamp(gated_speeds, 0.0, self.max_speed_ms)
509
+ gated_waypoints = torch.cat([
510
+ gated_waypoints[:, :, :3],
511
+ gated_speeds.unsqueeze(-1),
512
+ ], dim=-1)
513
 
514
  # Post-gate safety score
515
  safety = self.safety_score(justification_embedding)