"""
AI Accident Analysis — Traffic Rules Loader

Loads and indexes traffic rules from YAML for fast lookup.
Rules are loaded once at startup and cached in memory.
"""

import yaml
from pathlib import Path
from typing import List, Optional, Dict, Any, Union
from dataclasses import dataclass, field

from backend.app.config import settings
from backend.app.utils.logger import get_logger

logger = get_logger("rule_loader")


@dataclass
class TrafficRule:
    """A single traffic rule with visual indicators."""

    id: str
    title: str
    description: str
    severity: str  # CRITICAL, HIGH, MEDIUM, LOW (loader defaults to "MEDIUM")
    visual_indicators: List[str]
    fault_weight: float  # loader defaults to 0.5
    applicable_parties: List[str]
    category_id: str
    category_name: str


@dataclass
class RuleMatch:
    """A matched rule with confidence score."""

    rule: TrafficRule
    confidence: float
    matched_indicators: List[str]
    evidence_text: str = ""


class RuleLoader:
    """Load and index traffic rules from YAML for fast lookup.

    Every ``RuleLoader()`` call returns the same instance (see ``__new__``),
    so rules loaded once are visible to all holders of the loader.
    """

    _instance = None
    # Class-level empty defaults so ``is_loaded`` and the getters work before
    # ``load_rules`` has run. ``load_rules`` rebinds these on the instance
    # instead of mutating them, and the getters hand out copies, so these
    # shared class-level objects are never modified.
    _rules: List[TrafficRule] = []
    _rules_by_id: Dict[str, TrafficRule] = {}
    _categories: List[Dict[str, Any]] = []
    _all_indicators: List[str] = []

    def __new__(cls):
        # Create the instance on first call, then keep returning it.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    @property
    def is_loaded(self) -> bool:
        """True once at least one rule has been loaded."""
        return len(self._rules) > 0

    def load_rules(self, rules_path: Optional[Union[str, Path]] = None) -> None:
        """Load all rules from ``traffic_rules.yaml`` in the rules directory.

        The new rule set is published only after the whole file has been
        parsed, so a malformed entry never leaves the loader half-populated;
        on failure the previously loaded rules (if any) stay in place.

        Args:
            rules_path: Directory containing ``traffic_rules.yaml``. Defaults
                to ``settings.rules_path``. A ``str`` is converted to ``Path``.

        Raises:
            KeyError: If a category lacks ``id``/``name`` or a rule lacks
                ``id``/``title``/``description``.
            TypeError, ValueError: If a rule's ``fault_weight`` is not numeric.
            yaml.YAMLError: If the file is not valid YAML.
        """
        if rules_path is None:
            rules_path = settings.rules_path

        yaml_file = Path(rules_path) / "traffic_rules.yaml"

        if not yaml_file.exists():
            # Deliberate best-effort: log and keep whatever was loaded before.
            logger.error(f"Rules file not found: {yaml_file}")
            return

        logger.info(f"Loading traffic rules from: {yaml_file}")

        # Explicit UTF-8: rule text may contain non-ASCII characters and the
        # platform default encoding is not guaranteed to be UTF-8.
        with open(yaml_file, "r", encoding="utf-8") as f:
            data = yaml.safe_load(f)

        # safe_load returns None for an empty file.
        if data is None:
            data = {}
        if not isinstance(data, dict):
            logger.error(
                f"Rules file {yaml_file} must contain a mapping at the top "
                f"level, got {type(data).__name__}"
            )
            return

        # Build into locals first; only swap them in once everything parsed.
        rules: List[TrafficRule] = []
        rules_by_id: Dict[str, TrafficRule] = {}
        categories: List[Dict[str, Any]] = []
        all_indicators: List[str] = []

        # `or []` also covers keys that are present but null in the YAML.
        for category in data.get("categories") or []:
            cat_id = category["id"]
            cat_name = category["name"]
            categories.append({"id": cat_id, "name": cat_name})

            for rule_data in category.get("rules") or []:
                rule = self._build_rule(rule_data, cat_id, cat_name)
                if rule.id in rules_by_id:
                    # Last definition wins for ID lookup (both remain in the
                    # rule list), but surface the config error instead of
                    # overwriting silently.
                    logger.warning(
                        f"Duplicate rule id {rule.id!r} in category "
                        f"{cat_id!r}; overriding earlier definition for lookup"
                    )
                rules.append(rule)
                rules_by_id[rule.id] = rule
                all_indicators.extend(rule.visual_indicators)

        self._rules = rules
        self._rules_by_id = rules_by_id
        self._categories = categories
        self._all_indicators = all_indicators

        logger.info(
            f"Loaded {len(self._rules)} rules across "
            f"{len(self._categories)} categories with "
            f"{len(self._all_indicators)} visual indicators"
        )

    @staticmethod
    def _build_rule(
        rule_data: Dict[str, Any], cat_id: str, cat_name: str
    ) -> TrafficRule:
        """Build a TrafficRule from one YAML rule mapping, applying defaults."""
        fault_weight = rule_data.get("fault_weight")
        return TrafficRule(
            id=rule_data["id"],
            title=rule_data["title"],
            description=rule_data["description"],
            severity=rule_data.get("severity", "MEDIUM"),
            visual_indicators=rule_data.get("visual_indicators") or [],
            # Missing or null -> 0.5; YAML ints (e.g. `1`) become floats.
            fault_weight=0.5 if fault_weight is None else float(fault_weight),
            applicable_parties=rule_data.get("applicable_parties") or [],
            category_id=cat_id,
            category_name=cat_name,
        )

    def get_all_rules(self) -> List[TrafficRule]:
        """Return all loaded rules (a copy of the internal list)."""
        return list(self._rules)

    def get_rule_by_id(self, rule_id: str) -> Optional[TrafficRule]:
        """Look up a rule by its ID; return None if it is unknown."""
        return self._rules_by_id.get(rule_id)

    def get_rules_by_category(self, category_id: str) -> List[TrafficRule]:
        """Get all rules in a category, in file order."""
        return [r for r in self._rules if r.category_id == category_id]

    def get_categories(self) -> List[Dict[str, Any]]:
        """Get list of all categories as ``{"id", "name"}`` dicts (copies)."""
        return [dict(cat) for cat in self._categories]

    def get_all_visual_indicators(self) -> List[str]:
        """Get all visual indicators for prompt construction.

        Indicators are concatenated in rule order; duplicates across rules
        are kept. Returns a copy of the internal list.
        """
        return list(self._all_indicators)

    def get_rules_summary(self) -> Dict[str, Any]:
        """Get a summary of loaded rules for the API."""
        # Group once instead of rescanning every rule per category.
        rules_by_category: Dict[str, List[TrafficRule]] = {}
        for rule in self._rules:
            rules_by_category.setdefault(rule.category_id, []).append(rule)

        summary: Dict[str, Any] = {
            "total_rules": len(self._rules),
            "categories": [],
        }
        for cat in self._categories:
            cat_rules = rules_by_category.get(cat["id"], [])
            summary["categories"].append({
                "id": cat["id"],
                "name": cat["name"],
                "rule_count": len(cat_rules),
                "rules": [
                    {"id": r.id, "title": r.title, "severity": r.severity}
                    for r in cat_rules
                ],
            })
        return summary


# Singleton instance
rule_loader = RuleLoader()