Spaces:
Running on Zero
Running on Zero
| """ | |
| AI Accident Analysis — Fault Deducer | |
| Determines probable cause and fault assignment based on matched | |
| violations, their severity weights, and physical evidence. | |
| """ | |
| import json | |
| from typing import List, Dict, Optional, Tuple | |
| from backend.app.db.database import db | |
| from backend.app.rules.rule_loader import rule_loader | |
| from backend.app.utils.logger import get_logger | |
logger = get_logger("fault_deducer")

# Per-severity weight used when scoring violations: it scales each violation's
# contribution to a party's fault score and to the overall confidence average.
# Callers look severities up with a 1.0 fallback, so unknown levels count as MEDIUM.
SEVERITY_MULTIPLIER = dict(
    CRITICAL=3.0,
    HIGH=2.0,
    MEDIUM=1.0,
    LOW=0.5,
)
class FaultDeducer:
    """
    Determines probable cause and fault assignment for an accident case.

    Scoring algorithm:
        fault_score(party) = Σ (violation.confidence × violation.fault_weight × severity_multiplier)

    The party with the highest strictly positive fault score is assigned
    primary fault (the first such party wins a tie). If no violation can be
    attributed to any party, no primary party is assigned. Fault percentages
    are each party's share of the summed scores, or equal shares when every
    score is zero.

    The class stores no per-instance state.
    """

    async def deduce_fault(self, case_id: int) -> Dict:
        """
        Analyze a case's violations and evidence to determine fault.

        Loads violations, parties and analyses for ``case_id`` from ``db``,
        scores each party, builds the narrative and summary, saves the result
        (best-effort, see ``_save_result``) and returns it.

        Args:
            case_id: Identifier of the case to analyze.

        Returns:
            Dict with keys ``case_id``, ``primary_fault_party_id``,
            ``primary_fault_party_label``, ``fault_distribution``,
            ``probable_cause``, ``overall_confidence``, ``analysis_summary``,
            ``violation_count`` and ``party_scores``. The primary-party fields
            are ``None`` when no violation is attributable to any party.
        """
        logger.info(f"Starting fault deduction for case {case_id}")
        violations = await db.get_violations_by_case(case_id)
        parties = await db.get_parties_by_case(case_id)
        analyses = await db.get_analyses_by_case(case_id)

        # Degenerate inputs get fixed, explanatory results.
        if not violations:
            result = self._no_violations_result(case_id, parties)
            await self._save_result(case_id, result)
            return result
        if not parties:
            result = self._no_parties_result(case_id, violations)
            await self._save_result(case_id, result)
            return result

        labels = self._label_map(parties)
        party_scores = self._calculate_fault_scores(violations, parties)
        fault_distribution = self._compute_distribution(party_scores, parties)
        # None when every party scored zero (nothing attributable).
        primary_party_id = self._primary_party_id(party_scores)

        probable_cause = self._generate_probable_cause(
            violations, parties, party_scores, analyses
        )
        analysis_summary = self._generate_summary(
            violations, parties, fault_distribution, primary_party_id
        )
        overall_confidence = self._weighted_confidence(violations)

        result = {
            "case_id": case_id,
            "primary_fault_party_id": primary_party_id,
            "primary_fault_party_label": labels.get(primary_party_id),
            "fault_distribution": fault_distribution,
            "probable_cause": probable_cause,
            "overall_confidence": overall_confidence,
            "analysis_summary": analysis_summary,
            "violation_count": len(violations),
            "party_scores": {
                labels.get(pid, f"Party {pid}"): round(score, 3)
                for pid, score in party_scores.items()
            },
        }

        await self._save_result(case_id, result)
        logger.info(
            f"Case {case_id} fault analysis: primary fault = "
            f"{result['primary_fault_party_label']}, "
            f"confidence = {overall_confidence}, "
            f"{len(violations)} violations"
        )
        return result

    @staticmethod
    def _label_map(parties: List[dict]) -> Dict[int, str]:
        """Map party id -> label; the first label wins if an id repeats."""
        labels: Dict[int, str] = {}
        for party in parties:
            labels.setdefault(party["id"], party["label"])
        return labels

    @staticmethod
    def _primary_party_id(party_scores: Dict[int, float]) -> Optional[int]:
        """
        Return the id of the highest-scoring party, or ``None``.

        ``None`` is returned when there are no scores or the best score is not
        positive. In that case no violation was attributed to any party, and
        naming a "primary" party would be arbitrary. On ties, the first
        maximal entry wins (``max`` semantics).
        """
        if not party_scores:
            return None
        best_id = max(party_scores, key=party_scores.get)
        return best_id if party_scores[best_id] > 0 else None

    @staticmethod
    def _weighted_confidence(
        violations: List[dict],
        multipliers: Optional[Dict[str, float]] = None,
    ) -> float:
        """
        Severity-weighted average of violation confidences, rounded to 3 places.

        Args:
            violations: Rows with a numeric ``confidence`` and optional
                ``severity`` (missing severity counts as "MEDIUM").
            multipliers: Severity -> weight table; defaults to
                ``SEVERITY_MULTIPLIER``. Unknown severities weigh 1.0.

        Returns:
            The weighted mean, or 0.0 when the total weight is zero (e.g. no
            violations). The divisor is the true total weight, so a lone
            low-weight violation is not scaled down.
        """
        table = SEVERITY_MULTIPLIER if multipliers is None else multipliers
        weighted_sum = 0.0
        total_weight = 0.0
        for violation in violations:
            weight = table.get(violation.get("severity", "MEDIUM"), 1.0)
            weighted_sum += violation["confidence"] * weight
            total_weight += weight
        if total_weight <= 0:
            return 0.0
        return round(weighted_sum / total_weight, 3)

    def _calculate_fault_scores(
        self, violations: List[dict], parties: List[dict]
    ) -> Dict[int, float]:
        """
        Calculate a raw fault score for every party.

        Score = Σ (confidence × fault_weight × severity_multiplier) over the
        party's violations. ``fault_weight`` comes from the rule returned by
        ``rule_loader.get_rule_by_id``; 0.5 is used when no rule is found.
        Violations without a ``party_id``, or whose ``party_id`` is not one of
        ``parties``, are skipped.

        Returns:
            Mapping of party id -> score (0.0 for parties with no attributed
            violations).
        """
        scores = {p["id"]: 0.0 for p in parties}
        for violation in violations:
            party_id = violation.get("party_id")
            if party_id is None or party_id not in scores:
                continue
            rule = rule_loader.get_rule_by_id(violation["rule_id"])
            fault_weight = rule.fault_weight if rule else 0.5
            severity = violation.get("severity", "MEDIUM")
            multiplier = SEVERITY_MULTIPLIER.get(severity, 1.0)
            scores[party_id] += violation["confidence"] * fault_weight * multiplier
        return scores

    def _compute_distribution(
        self, party_scores: Dict[int, float], parties: List[dict]
    ) -> Dict[str, float]:
        """
        Convert raw scores to a fault percentage per party label.

        Percentages are rounded to one decimal place, so they may not sum to
        exactly 100. When all scores are zero every party gets an equal share.
        With no parties the result is empty.
        """
        if not parties:
            return {}
        total = sum(party_scores.values())
        if total == 0:
            share = round(100.0 / len(parties), 1)
            return {p["label"]: share for p in parties}
        return {
            p["label"]: round(party_scores.get(p["id"], 0) / total * 100, 1)
            for p in parties
        }

    def _generate_probable_cause(
        self, violations: List[dict], parties: List[dict],
        party_scores: Dict[int, float], analyses: List[dict]
    ) -> str:
        """
        Generate a human-readable, newline-separated probable-cause narrative.

        Violations are grouped per party. A violation's ``party_label`` is used
        if present. Otherwise the label of the party matching its
        ``party_id`` is used, and "Unknown" if neither is available. Each group
        lists up to five violations, CRITICAL first, then HIGH, then the rest.
        A conclusion names the primary contributing party, or states that none
        could be determined.

        ``analyses`` is accepted for interface compatibility but not used yet.
        """
        if not violations:
            return "Insufficient evidence to determine probable cause."

        labels = self._label_map(parties)
        # Group by party label, preserving first-seen order.
        grouped: Dict[str, List[dict]] = {}
        for violation in violations:
            label = (
                violation.get("party_label")
                or labels.get(violation.get("party_id"))
                or "Unknown"
            )
            grouped.setdefault(label, []).append(violation)

        # sorted() is stable: CRITICAL, then HIGH, then others in input order.
        severity_rank = {"CRITICAL": 0, "HIGH": 1}
        lines = ["Based on the analysis of accident scene photographs:"]
        for label, party_violations in grouped.items():
            ordered = sorted(
                party_violations,
                key=lambda v: severity_rank.get(v.get("severity"), 2),
            )
            descriptions = [
                f"{v['rule_title']} ({v.get('severity', 'MEDIUM')}, "
                f"confidence: {v['confidence']:.0%})"
                for v in ordered[:5]  # top 5 most severe
            ]
            lines.append(
                f"{label} was found to have the following violations: "
                + "; ".join(descriptions) + "."
            )

        primary_id = self._primary_party_id(party_scores)
        if primary_id is not None:
            lines.append(
                "Based on the severity and number of violations, "
                f"{labels.get(primary_id, 'Unknown')} is assessed as the "
                "primary contributing party."
            )
        else:
            lines.append(
                "No violation could be attributed to a specific party, so no "
                "primary contributing party is assessed."
            )
        return "\n".join(lines)

    def _generate_summary(
        self, violations: List[dict], parties: List[dict],
        distribution: Dict[str, float], primary_party_id: Optional[int]
    ) -> str:
        """
        Generate a concise one-paragraph summary of the fault analysis.

        Reports violation counts (total, critical, high), the primary fault
        party (or that none could be determined when ``primary_party_id`` is
        ``None``), and the percentage distribution.
        """
        critical_count = sum(1 for v in violations if v.get("severity") == "CRITICAL")
        high_count = sum(1 for v in violations if v.get("severity") == "HIGH")
        dist_str = ", ".join(
            f"{label}: {pct}%" for label, pct in distribution.items()
        )
        if primary_party_id is None:
            primary_sentence = "No primary fault party could be determined. "
        else:
            primary_label = self._label_map(parties).get(primary_party_id, "Unknown")
            primary_sentence = f"Primary fault assigned to {primary_label}. "
        return (
            f"Analysis identified {len(violations)} traffic violation(s) "
            f"({critical_count} critical, {high_count} high severity). "
            + primary_sentence
            + f"Fault distribution: {dist_str}."
        )

    def _no_violations_result(self, case_id: int, parties: List[dict]) -> Dict:
        """Result when no violations are found: equal split, zero confidence."""
        return {
            "case_id": case_id,
            "primary_fault_party_id": None,
            "primary_fault_party_label": None,
            # All-zero scores -> equal shares ({} when there are no parties).
            "fault_distribution": self._compute_distribution({}, parties),
            "probable_cause": (
                "No clear traffic violations were identified from the available "
                "photographs. Additional evidence or investigation may be required."
            ),
            "overall_confidence": 0.0,
            "analysis_summary": "No violations detected. Manual review recommended.",
            "violation_count": 0,
            "party_scores": {},
        }

    def _no_parties_result(self, case_id: int, violations: List[dict]) -> Dict:
        """Result when violations exist but no parties were identified."""
        return {
            "case_id": case_id,
            "primary_fault_party_id": None,
            "primary_fault_party_label": None,
            "fault_distribution": {},
            "probable_cause": (
                f"{len(violations)} violation(s) detected but no parties could be "
                "identified from the photographs. Manual party identification required."
            ),
            "overall_confidence": 0.3,
            "analysis_summary": "Violations detected but parties not identifiable.",
            "violation_count": len(violations),
            "party_scores": {},
        }

    async def _save_result(self, case_id: int, result: Dict):
        """
        Persist the fault analysis via ``db.save_fault_analysis``.

        Best-effort: any exception is logged and swallowed, so a storage
        failure never prevents the caller from receiving the computed result.
        The distribution is stored as a JSON string.
        """
        try:
            await db.save_fault_analysis(
                case_id=case_id,
                primary_fault_party_id=result.get("primary_fault_party_id"),
                fault_distribution_json=json.dumps(result.get("fault_distribution", {})),
                probable_cause=result.get("probable_cause", ""),
                overall_confidence=result.get("overall_confidence", 0.0),
                analysis_summary=result.get("analysis_summary", ""),
            )
        except Exception as e:
            logger.error(f"Failed to save fault analysis for case {case_id}: {e}")
# Shared module-level instance. FaultDeducer keeps no per-instance state, so a
# single object can serve every caller.
fault_deducer: FaultDeducer = FaultDeducer()