| """ |
| Chain-of-Thought (CoT) Reasoning Module for Autonomous Driving Safety. |
| |
| Inspired by Alpamayo-R1's Chain-of-Causation and AgentThink's structured reasoning. |
| Implements a multi-stage reasoning pipeline: |
| |
| Stage 1: Scene Narration — "What do I see?" |
| Encodes BEV + perception outputs into a structured scene description vector. |
| Identifies all actors, road topology, traffic signals, weather. |
| |
| Stage 2: Risk Assessment — "What could go wrong?" |
| For each actor/hazard, predicts threat level, time-to-collision (TTC), |
| probability of incursion into ego's planned path. |
| |
| Stage 3: Causal Reasoning — "Why should I act?" |
| Chains scene evidence → risk → required behavior. |
| Produces an interpretable reasoning trace (vector + decodable tokens). |
| |
| Stage 4: Decision Gate — "What should I do?" |
| Outputs a safety-verified action decision that overrides the base planner |
| when the reasoning chain identifies danger the planner missed. |
| |
| The CoT module sits BETWEEN perception and planning, enriching the BEV |
| features with explicit safety reasoning before trajectory generation. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from typing import Dict, Optional, Tuple, List |
| import math |
|
|
|
|
| |
| |
| |
|
|
class SceneNarrationEncoder(nn.Module):
    """Encode BEV features into actor/road tokens and a pooled scene token.

    Two sets of learned queries (actors and road elements) cross-attend to
    the projected BEV grid through separate transformer decoders. Linear
    heads read per-query attributes off the resulting tokens, and an
    existence-weighted pool of the actor tokens is fused with the mean road
    token into a single scene summary vector for downstream reasoning.

    Args:
        bev_channels: Channel count of the incoming BEV feature map.
        num_actor_queries: Number of learned actor queries.
        num_road_queries: Number of learned road-element queries.
        d_model: Token width used by the decoders and all heads.
        nhead: Attention heads per decoder layer.
        num_layers: Number of layers in each of the two decoders.
        num_actor_classes: Width of the actor class logits.
        num_road_types: Width of the road type logits.
        num_signal_states: Width of the road signal-state logits.
        dropout: Dropout probability inside the decoder layers.
    """

    def __init__(
        self,
        bev_channels: int = 256,
        num_actor_queries: int = 64,
        num_road_queries: int = 32,
        d_model: int = 256,
        nhead: int = 8,
        num_layers: int = 3,
        num_actor_classes: int = 10,
        num_road_types: int = 7,
        num_signal_states: int = 4,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.d_model = d_model
        self.num_actor_queries = num_actor_queries
        self.num_road_queries = num_road_queries

        # 1x1 projection of the BEV map to the token width.
        self.bev_proj = nn.Sequential(
            nn.Conv2d(bev_channels, d_model, 1),
            nn.BatchNorm2d(d_model),
            nn.GELU(),
        )

        # Learned query sets, shared across the batch.
        self.actor_queries = nn.Parameter(torch.randn(num_actor_queries, d_model))
        self.road_queries = nn.Parameter(torch.randn(num_road_queries, d_model))

        # Independent decoders so actor and road queries do not share weights.
        self.actor_decoder = self._build_decoder(d_model, nhead, num_layers, dropout)
        self.road_decoder = self._build_decoder(d_model, nhead, num_layers, dropout)

        # Per-actor attribute heads.
        self.actor_class_head = nn.Linear(d_model, num_actor_classes)
        self.actor_exist_head = nn.Linear(d_model, 1)
        self.actor_dist_head = nn.Linear(d_model, 1)
        self.actor_vel_head = nn.Linear(d_model, 2)
        self.actor_threat_head = nn.Linear(d_model, 1)

        # Per-road-element attribute heads.
        self.road_type_head = nn.Linear(d_model, num_road_types)
        self.road_state_head = nn.Linear(d_model, num_signal_states)

        # Fuses pooled actor and road tokens into one scene vector.
        self.scene_summary = nn.Sequential(
            nn.Linear(d_model * 2, d_model),
            nn.GELU(),
            nn.Linear(d_model, d_model),
        )

    @staticmethod
    def _build_decoder(
        d_model: int, nhead: int, num_layers: int, dropout: float
    ) -> nn.TransformerDecoder:
        """Build a batch-first GELU transformer decoder with a 4x feed-forward."""
        layer = nn.TransformerDecoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=d_model * 4,
            dropout=dropout, batch_first=True, activation="gelu",
        )
        return nn.TransformerDecoder(layer, num_layers=num_layers)

    def forward(
        self,
        bev_features: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """Run scene narration on a BEV feature map.

        Args:
            bev_features: Tensor of shape (B, bev_channels, H, W).

        Returns:
            Dict with:
              - "actor_tokens": (B, num_actor_queries, d_model)
              - "actor_class": class logits, (B, num_actor_queries, num_actor_classes)
              - "actor_exist": sigmoid existence score, (B, num_actor_queries, 1)
              - "actor_distance": ReLU-clamped (non-negative), (B, num_actor_queries, 1)
              - "actor_velocity": (B, num_actor_queries, 2)
              - "actor_threat": sigmoid score, (B, num_actor_queries, 1)
              - "road_tokens": (B, num_road_queries, d_model)
              - "road_type": logits, (B, num_road_queries, num_road_types)
              - "road_signal_state": logits, (B, num_road_queries, num_signal_states)
              - "scene_token": (B, d_model)
        """
        batch_size = bev_features.shape[0]

        # (B, C, H, W) -> (B, H*W, d_model): one memory token per BEV cell.
        bev_seq = self.bev_proj(bev_features).flatten(2).permute(0, 2, 1)

        actor_queries = self.actor_queries.unsqueeze(0).expand(batch_size, -1, -1)
        actor_tokens = self.actor_decoder(actor_queries, bev_seq)

        road_queries = self.road_queries.unsqueeze(0).expand(batch_size, -1, -1)
        road_tokens = self.road_decoder(road_queries, bev_seq)

        actor_class = self.actor_class_head(actor_tokens)
        actor_exist = torch.sigmoid(self.actor_exist_head(actor_tokens))
        actor_dist = F.relu(self.actor_dist_head(actor_tokens))
        actor_vel = self.actor_vel_head(actor_tokens)
        actor_threat = torch.sigmoid(self.actor_threat_head(actor_tokens))

        road_type = self.road_type_head(road_tokens)
        road_state = self.road_state_head(road_tokens)

        # Existence-weighted actor pooling. The denominator is floored at 1 so
        # a scene with little total existence mass is not amplified.
        actor_pool = (actor_tokens * actor_exist).sum(dim=1) / actor_exist.sum(dim=1).clamp(min=1)
        road_pool = road_tokens.mean(dim=1)
        scene_token = self.scene_summary(torch.cat([actor_pool, road_pool], dim=-1))

        return {
            "actor_tokens": actor_tokens,
            "actor_class": actor_class,
            "actor_exist": actor_exist,
            "actor_distance": actor_dist,
            "actor_velocity": actor_vel,
            "actor_threat": actor_threat,
            "road_tokens": road_tokens,
            "road_type": road_type,
            "road_signal_state": road_state,
            "scene_token": scene_token,
        }
|
|
|
|
| |
| |
| |
|
|
class RiskAssessmentModule(nn.Module):
    """Per-actor and scene-level risk estimation.

    For every actor token this module predicts a time-to-collision value
    (Softplus, so strictly positive), a collision probability (sigmoid) and
    risk-level logits. It also selects the actor with the highest
    existence-weighted collision probability and produces one aggregate
    scene risk score from an existence-weighted pool of the actor features.

    Args:
        d_model: Width of the actor tokens and internal risk features.
        num_risk_levels: Number of risk-level logits per actor.
        nhead: Heads of the actor-to-actor self-attention.
        exist_threshold: Actors whose existence score is below this value
            are excluded as attention keys.
        dropout: Dropout probability of the self-attention.
    """

    RISK_LEVELS = ["none", "low", "medium", "high", "critical"]

    def __init__(
        self,
        d_model: int = 256,
        num_risk_levels: int = 5,
        nhead: int = 8,
        exist_threshold: float = 0.3,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.d_model = d_model
        self.num_risk_levels = num_risk_levels
        self.exist_threshold = exist_threshold

        # Input: token (d_model) + distance (1) + velocity (2) + threat (1).
        self.risk_mlp = nn.Sequential(
            nn.Linear(d_model + 1 + 2 + 1, d_model),
            nn.GELU(),
            nn.Linear(d_model, d_model),
            nn.GELU(),
        )

        self.ttc_head = nn.Sequential(
            nn.Linear(d_model, 64),
            nn.GELU(),
            nn.Linear(64, 1),
            nn.Softplus(),
        )

        self.collision_prob_head = nn.Sequential(
            nn.Linear(d_model, 64),
            nn.GELU(),
            nn.Linear(64, 1),
            nn.Sigmoid(),
        )

        self.risk_level_head = nn.Sequential(
            nn.Linear(d_model, 64),
            nn.GELU(),
            nn.Linear(64, num_risk_levels),
        )

        # Lets each actor's risk feature see the other actors.
        self.actor_self_attn = nn.MultiheadAttention(
            d_model, num_heads=nhead, batch_first=True, dropout=dropout
        )
        self.attn_norm = nn.LayerNorm(d_model)

        self.scene_risk = nn.Sequential(
            nn.Linear(d_model, 128),
            nn.GELU(),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )

    def forward(
        self,
        actor_tokens: torch.Tensor,
        actor_exist: torch.Tensor,
        actor_distance: torch.Tensor,
        actor_velocity: torch.Tensor,
        actor_threat: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """Assess risk for each actor and for the scene as a whole.

        Args:
            actor_tokens: (B, Na, d_model) actor features.
            actor_exist: (B, Na, 1) existence scores.
            actor_distance: (B, Na, 1) distances.
            actor_velocity: (B, Na, 2) velocities.
            actor_threat: (B, Na, 1) threat scores.

        Returns:
            Dict with "risk_features" (B, Na, d_model), "ttc" (B, Na, 1),
            "collision_probability" (B, Na, 1), "risk_level_logits"
            (B, Na, num_risk_levels), "worst_actor_feature" (B, d_model),
            "worst_actor_idx" (B,) and "aggregate_scene_risk" (B, 1).
        """
        batch_size = actor_tokens.shape[0]

        actor_input = torch.cat(
            [actor_tokens, actor_distance, actor_velocity, actor_threat], dim=-1
        )
        risk_feat = self.risk_mlp(actor_input)

        # Low-existence actors are excluded as attention keys.
        ignore = actor_exist.squeeze(-1) < self.exist_threshold
        # If every key of a sample is masked, the attention softmax runs over
        # only -inf and returns NaN, which poisons every output. Fall back to
        # unmasked attention for such samples. Their contribution to the
        # pooled outputs is still scaled by the (small) existence scores.
        fully_masked = ignore.all(dim=1, keepdim=True)
        ignore = ignore & ~fully_masked

        attn_out, _ = self.actor_self_attn(
            risk_feat, risk_feat, risk_feat,
            key_padding_mask=ignore,
        )
        risk_feat = self.attn_norm(risk_feat + attn_out)

        ttc = self.ttc_head(risk_feat)
        collision_prob = self.collision_prob_head(risk_feat)
        risk_level = self.risk_level_head(risk_feat)

        # Worst-case actor: highest collision probability weighted by existence.
        weighted_risk = collision_prob.squeeze(-1) * actor_exist.squeeze(-1)
        worst_idx = weighted_risk.argmax(dim=1)
        batch_idx = torch.arange(batch_size, device=risk_feat.device)
        worst_actor = risk_feat[batch_idx, worst_idx]

        # Existence-weighted pooling. The denominator is floored at 1.
        exist_weight = actor_exist / actor_exist.sum(dim=1, keepdim=True).clamp(min=1)
        pooled = (risk_feat * exist_weight).sum(dim=1)
        agg_risk = self.scene_risk(pooled)

        return {
            "risk_features": risk_feat,
            "ttc": ttc,
            "collision_probability": collision_prob,
            "risk_level_logits": risk_level,
            "worst_actor_feature": worst_actor,
            "worst_actor_idx": worst_idx,
            "aggregate_scene_risk": agg_risk,
        }
|
|
|
|
| |
| |
| |
|
|
class CausalReasoningChain(nn.Module):
    """Structured causal reasoning over scene and risk evidence.

    Four learned "thought" tokens first cross-attend to an evidence
    sequence (scene token, worst-actor feature, ego embedding, pooled risk
    features) and are then refined by a causally masked transformer
    encoder, so that step i only attends to steps <= i in that stage:

        1. Situation assessment (what's happening)
        2. Hazard identification (what's dangerous)
        3. Action justification (why act this way)
        4. Action decision (what to do)

    NOTE(review): the evidence cross-attention decoder is called without a
    target mask, so its self-attention between thought tokens is not
    causal. Only the subsequent encoder stage is masked.

    Args:
        d_model: Width of all tokens.
        nhead: Attention heads in both transformer stacks.
        num_layers: Layers of the causally masked reasoning encoder.
        num_behaviors: Number of action logits.
        ego_state_dim: Size of the raw ego-state vector.

    The ego projection is a registered submodule (``ego_proj``). State
    dicts saved before it existed lack the ``ego_proj.*`` keys and need
    ``strict=False`` when loaded.
    """

    NUM_THOUGHT_STEPS = 4

    def __init__(
        self,
        d_model: int = 256,
        nhead: int = 8,
        num_layers: int = 4,
        num_behaviors: int = 10,
        ego_state_dim: int = 6,
    ):
        super().__init__()
        self.d_model = d_model

        # One learned query per thought step.
        self.thought_embeddings = nn.Parameter(
            torch.randn(self.NUM_THOUGHT_STEPS, d_model)
        )

        # Causally masked refinement of the thought tokens.
        reason_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=d_model * 4,
            dropout=0.1, batch_first=True, activation="gelu",
        )
        self.reasoning_transformer = nn.TransformerEncoder(
            reason_layer, num_layers=num_layers,
        )

        # Thought tokens attend to the evidence sequence.
        cross_layer = nn.TransformerDecoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=d_model * 4,
            dropout=0.1, batch_first=True, activation="gelu",
        )
        self.evidence_cross_attn = nn.TransformerDecoder(
            cross_layer, num_layers=2,
        )

        # One head per thought step.
        self.situation_head = nn.Linear(d_model, d_model)
        self.hazard_head = nn.Linear(d_model, d_model)
        self.justification_head = nn.Linear(d_model, d_model)
        self.action_head = nn.Linear(d_model, num_behaviors)

        self.override_confidence = nn.Sequential(
            nn.Linear(d_model, 64),
            nn.GELU(),
            nn.Linear(64, 1),
            nn.Sigmoid(),
        )

        self.urgency_head = nn.Sequential(
            nn.Linear(d_model, 64),
            nn.GELU(),
            nn.Linear(64, 1),
            nn.Sigmoid(),
        )

        # Ego-state projection. It must be a registered submodule: building a
        # new nn.Linear inside forward() would give untrained, freshly random
        # weights on every call and make the output non-deterministic.
        self.ego_proj = nn.Sequential(
            nn.Linear(ego_state_dim, d_model),
            nn.GELU(),
        )

    def _causal_mask(self, sz: int, device: torch.device) -> torch.Tensor:
        """Return an (sz, sz) additive mask: -inf above the diagonal, 0 elsewhere."""
        return torch.triu(torch.ones(sz, sz, device=device) * float('-inf'), diagonal=1)

    def forward(
        self,
        scene_token: torch.Tensor,
        risk_features: torch.Tensor,
        worst_actor_feature: torch.Tensor,
        aggregate_risk: torch.Tensor,
        ego_state: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """Run the four-step reasoning chain.

        Args:
            scene_token: (B, d_model) scene summary.
            risk_features: (B, Na, d_model) per-actor risk features; they
                are mean-pooled over the actor axis.
            worst_actor_feature: (B, d_model) feature of the riskiest actor.
            aggregate_risk: Accepted for interface compatibility; not read
                by this method.
            ego_state: (B, ego_state_dim) raw ego state.

        Returns:
            Dict with "situation_embedding", "hazard_embedding" and
            "justification_embedding" (each (B, d_model)),
            "cot_action_logits" (B, num_behaviors), "override_confidence"
            and "urgency" (each (B, 1), sigmoid), and "reasoning_trace"
            (B, NUM_THOUGHT_STEPS, d_model).
        """
        B = scene_token.shape[0]
        device = scene_token.device

        ego_embed = self.ego_proj(ego_state)
        risk_pool = risk_features.mean(dim=1)

        # Evidence sequence of four tokens: (B, 4, d_model).
        evidence = torch.stack(
            [scene_token, worst_actor_feature, ego_embed, risk_pool], dim=1
        )

        thoughts = self.thought_embeddings.unsqueeze(0).expand(B, -1, -1)
        thoughts = self.evidence_cross_attn(thoughts, evidence)

        mask = self._causal_mask(self.NUM_THOUGHT_STEPS, device)
        thoughts = self.reasoning_transformer(thoughts, mask=mask)

        situation = self.situation_head(thoughts[:, 0])
        hazard = self.hazard_head(thoughts[:, 1])
        justification = self.justification_head(thoughts[:, 2])

        # All decision outputs are read from the final thought step.
        decision = thoughts[:, 3]
        action_logits = self.action_head(decision)
        override_conf = self.override_confidence(decision)
        urgency = self.urgency_head(decision)

        return {
            "situation_embedding": situation,
            "hazard_embedding": hazard,
            "justification_embedding": justification,
            "cot_action_logits": action_logits,
            "override_confidence": override_conf,
            "urgency": urgency,
            "reasoning_trace": thoughts,
        }
|
|
|
|
| |
| |
| |
|
|
class SafetyDecisionGate(nn.Module):
    """Blend the base planner trajectory with a CoT-proposed trajectory.

    A learned blend weight, scaled by urgency, interpolates between the
    planner waypoints and waypoints proposed from the justification
    embedding. The speed channel (index 3 of each waypoint) is bounded
    twice, on the proposal and again after blending, so the gated speed
    never exceeds the planner speed and always lies in [0, max_speed_ms].
    The gate can therefore only slow the vehicle down, never speed it up.

    Args:
        d_model: Width of the justification embedding.
        num_waypoints: Waypoints per trajectory; each waypoint has 4 values.
        max_speed_ms: Upper clamp for the speed channel (presumably m/s,
            going by the name).
    """

    def __init__(
        self,
        d_model: int = 256,
        num_waypoints: int = 20,
        max_speed_ms: float = 8.94,
    ):
        super().__init__()
        self.d_model = d_model
        self.num_waypoints = num_waypoints
        self.max_speed_ms = max_speed_ms

        flat_wp_dim = 4 * num_waypoints

        # Proposes replacement waypoints from the justification, the
        # flattened planner waypoints, override confidence and urgency.
        self.traj_modifier = nn.Sequential(
            nn.Linear(d_model + flat_wp_dim + 1 + 1, d_model),
            nn.GELU(),
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Linear(d_model, flat_wp_dim),
        )

        # Scalar blend weight in (0, 1).
        self.blend_weight = nn.Sequential(
            nn.Linear(d_model + 1 + 1, 64),
            nn.GELU(),
            nn.Linear(64, 1),
            nn.Sigmoid(),
        )

        # Scalar safety score in (0, 1) read from the justification.
        self.safety_score = nn.Sequential(
            nn.Linear(d_model, 64),
            nn.GELU(),
            nn.Linear(64, 1),
            nn.Sigmoid(),
        )

    @staticmethod
    def _with_speed(waypoints: torch.Tensor, speeds: torch.Tensor) -> torch.Tensor:
        """Return waypoints whose 4th channel is replaced by ``speeds``."""
        return torch.cat([waypoints[:, :, :3], speeds.unsqueeze(-1)], dim=-1)

    def forward(
        self,
        planner_waypoints: torch.Tensor,
        justification_embedding: torch.Tensor,
        override_confidence: torch.Tensor,
        urgency: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """Gate the planner trajectory with the CoT proposal.

        Args:
            planner_waypoints: (B, num_waypoints, 4); channel 3 is speed.
            justification_embedding: (B, d_model).
            override_confidence: (B, 1).
            urgency: (B, 1).

        Returns:
            Dict with "gated_waypoints" and "cot_waypoints" (both
            (B, num_waypoints, 4)), "blend_alpha" (B, 1) and
            "post_gate_safety_score" (B, 1).
        """
        batch = planner_waypoints.shape[0]
        reference_speed = planner_waypoints[:, :, 3]
        flat_plan = planner_waypoints.reshape(batch, -1)

        # Blend weight, damped by urgency so low urgency keeps the planner.
        gate_features = torch.cat(
            (justification_embedding, override_confidence, urgency), dim=-1
        )
        alpha = self.blend_weight(gate_features)
        alpha = alpha * urgency

        # CoT-proposed trajectory.
        proposal_features = torch.cat(
            (justification_embedding, flat_plan, override_confidence, urgency),
            dim=-1,
        )
        proposal = self.traj_modifier(proposal_features)
        proposal = proposal.reshape(batch, self.num_waypoints, 4)

        # Proposal speed: non-negative, at most the planner's, at most the cap.
        proposal_speed = torch.minimum(reference_speed, F.relu(proposal[:, :, 3]))
        proposal_speed = proposal_speed.clamp(min=0.0, max=self.max_speed_ms)
        cot_waypoints = self._with_speed(proposal, proposal_speed)

        # Convex blend of planner and proposal.
        weight = alpha.unsqueeze(-1)
        blended = (1 - weight) * planner_waypoints + weight * cot_waypoints

        # Re-apply the speed bounds after blending.
        final_speed = torch.minimum(blended[:, :, 3], reference_speed)
        final_speed = final_speed.clamp(min=0.0, max=self.max_speed_ms)
        gated_waypoints = self._with_speed(blended, final_speed)

        safety = self.safety_score(justification_embedding)

        return {
            "gated_waypoints": gated_waypoints,
            "cot_waypoints": cot_waypoints,
            "blend_alpha": alpha,
            "post_gate_safety_score": safety,
        }
|
|
|
|
| |
| |
| |
|
|
class ChainOfThoughtReasoning(nn.Module):
    """End-to-end Chain-of-Thought safety reasoning pipeline.

    Data flow:
        BEV features + ego state
          -> scene narration (actor/road tokens, scene token)
          -> risk assessment (per-actor and aggregate risk)
          -> causal reasoning (four thought steps)
          -> safety decision gate (only when planner waypoints are given)

    Outputs:
        1. BEV features enriched with the justification embedding
        2. Safety-gated waypoints (optional)
        3. The reasoning trace
        4. Per-actor risk breakdown
    """

    def __init__(
        self,
        bev_channels: int = 256,
        d_model: int = 256,
        num_actor_queries: int = 64,
        num_road_queries: int = 32,
        num_waypoints: int = 20,
        num_behaviors: int = 10,
        max_speed_ms: float = 8.94,
    ):
        super().__init__()
        self.d_model = d_model

        # Stage 1: scene narration.
        self.scene_narrator = SceneNarrationEncoder(
            bev_channels=bev_channels,
            num_actor_queries=num_actor_queries,
            num_road_queries=num_road_queries,
            d_model=d_model,
        )

        # Stage 2: risk assessment.
        self.risk_assessor = RiskAssessmentModule(d_model=d_model)

        # Stage 3: causal reasoning.
        self.causal_reasoner = CausalReasoningChain(
            d_model=d_model,
            num_behaviors=num_behaviors,
        )

        # Stage 4: safety gate.
        self.safety_gate = SafetyDecisionGate(
            d_model=d_model,
            num_waypoints=num_waypoints,
            max_speed_ms=max_speed_ms,
        )

        # Fuses the broadcast justification embedding back into the BEV map.
        self.bev_enrichment = nn.Sequential(
            nn.Conv2d(bev_channels + d_model, bev_channels, 1),
            nn.BatchNorm2d(bev_channels),
            nn.GELU(),
            nn.Conv2d(bev_channels, bev_channels, 3, padding=1),
            nn.BatchNorm2d(bev_channels),
            nn.GELU(),
        )

        # Projects the 6-value ego state to the token width.
        self.ego_proj = nn.Sequential(
            nn.Linear(6, d_model),
            nn.GELU(),
        )

    def forward(
        self,
        bev_features: torch.Tensor,
        ego_state: torch.Tensor,
        planner_waypoints: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """Run the full reasoning pipeline.

        Args:
            bev_features: (B, bev_channels, H, W) BEV feature map.
            ego_state: (B, 6) ego state vector.
            planner_waypoints: Optional base-planner trajectory. When given,
                the safety gate runs and its outputs are added under the
                "cot/" prefix.

        Returns:
            Dict with "enriched_bev" (residual-enriched BEV map) and
            "cot/..." entries for scene, risk, reasoning and, if
            applicable, gate outputs.
        """
        # Unpacking also enforces a 4-D input.
        _, _, height, width = bev_features.shape

        scene = self.scene_narrator(bev_features)

        risk = self.risk_assessor(
            actor_tokens=scene["actor_tokens"],
            actor_exist=scene["actor_exist"],
            actor_distance=scene["actor_distance"],
            actor_velocity=scene["actor_velocity"],
            actor_threat=scene["actor_threat"],
        )

        reason = self._run_causal_reasoning(
            scene["scene_token"],
            risk["risk_features"],
            risk["worst_actor_feature"],
            risk["aggregate_scene_risk"],
            self.ego_proj(ego_state),
        )
        justification = reason["justification_embedding"]

        # Broadcast the justification vector over every BEV cell, fuse, and
        # add the result residually to the input map.
        tiled = justification[:, :, None, None].expand(-1, -1, height, width)
        fused = self.bev_enrichment(torch.cat([bev_features, tiled], dim=1))
        enriched_bev = fused + bev_features

        output = {"enriched_bev": enriched_bev}

        scene_exports = (
            ("cot/actor_class", "actor_class"),
            ("cot/actor_exist", "actor_exist"),
            ("cot/actor_distance", "actor_distance"),
            ("cot/actor_velocity", "actor_velocity"),
        )
        risk_exports = (
            ("cot/ttc", "ttc"),
            ("cot/collision_probability", "collision_probability"),
            ("cot/risk_level_logits", "risk_level_logits"),
            ("cot/aggregate_risk", "aggregate_scene_risk"),
            ("cot/worst_actor_idx", "worst_actor_idx"),
        )
        reason_exports = (
            ("cot/action_logits", "cot_action_logits"),
            ("cot/override_confidence", "override_confidence"),
            ("cot/urgency", "urgency"),
            ("cot/reasoning_trace", "reasoning_trace"),
        )
        for exports, stage_out in (
            (scene_exports, scene),
            (risk_exports, risk),
            (reason_exports, reason),
        ):
            for out_key, src_key in exports:
                output[out_key] = stage_out[src_key]

        # The gate runs only when a base-planner trajectory is supplied.
        if planner_waypoints is not None:
            gate_output = self.safety_gate(
                planner_waypoints=planner_waypoints,
                justification_embedding=justification,
                override_confidence=reason["override_confidence"],
                urgency=reason["urgency"],
            )
            for name, value in gate_output.items():
                output[f"cot/{name}"] = value

        return output

    def _run_causal_reasoning(
        self, scene_token, risk_features, worst_actor, agg_risk, ego_embed,
    ):
        """Run the causal reasoner's submodules on a pre-computed ego embedding.

        Drives the layers of ``self.causal_reasoner`` directly, not through
        its ``forward``, so the ego embedding from ``self.ego_proj`` is the
        one used. ``agg_risk`` is accepted but not read.
        """
        reasoner = self.causal_reasoner
        batch = scene_token.shape[0]

        # Evidence sequence: scene, worst actor, ego, mean-pooled risk.
        pooled_risk = risk_features.mean(dim=1)
        evidence = torch.stack(
            [scene_token, worst_actor, ego_embed, pooled_risk], dim=1
        )

        steps = reasoner.thought_embeddings.unsqueeze(0).expand(batch, -1, -1)
        steps = reasoner.evidence_cross_attn(steps, evidence)
        causal = reasoner._causal_mask(reasoner.NUM_THOUGHT_STEPS, scene_token.device)
        steps = reasoner.reasoning_transformer(steps, mask=causal)

        # Steps 0-2 feed the embedding heads; step 3 feeds all decision heads.
        decision_step = steps[:, 3]
        return {
            "situation_embedding": reasoner.situation_head(steps[:, 0]),
            "hazard_embedding": reasoner.hazard_head(steps[:, 1]),
            "justification_embedding": reasoner.justification_head(steps[:, 2]),
            "cot_action_logits": reasoner.action_head(decision_step),
            "override_confidence": reasoner.override_confidence(decision_step),
            "urgency": reasoner.urgency_head(decision_step),
            "reasoning_trace": steps,
        }
|
|