"""
Scenario Database Management
"""

import json
import os
import re
from pathlib import Path
from typing import Dict, List, Optional, Set
from collections import defaultdict

from .scenario_models import ADASScenario


# Word tokenizer used by full-text search: runs of word characters, so
# punctuation attached to a word ("braking," / "(AEB)") does not block a match.
_WORD_RE = re.compile(r"\w+")


class ScenarioDatabase:
    """Scenario database management.

    Loads ADAS scenarios from a JSON file whose top level is an object with a
    ``"scenarios"`` list, and keeps in-memory lookup indexes by feature,
    type, source and tag.
    """

    def __init__(self, data_file: str = "data/scenarios/initial_scenarios.json"):
        """Load scenarios from ``data_file`` and build the lookup indexes.

        Args:
            data_file: Path to the scenario JSON file. A missing or unreadable
                file is reported on stdout and yields an empty database
                instead of raising.
        """
        self.data_file = Path(data_file)
        self.scenarios: Dict[str, ADASScenario] = {}
        # index name -> key -> list of scenario IDs (filled by _build_index)
        self.index: Dict[str, Dict[str, List[str]]] = {}
        self._load_scenarios()
        self._build_index()

    def _load_scenarios(self):
        """Load scenarios from the JSON file into ``self.scenarios``.

        Problems are reported on stdout rather than raised, so the database
        can always be constructed. A malformed individual entry is skipped
        without discarding the entries that follow it.
        """
        if not self.data_file.exists():
            print(f"⚠️ Scenario database file not found: {self.data_file}")
            print(" Run 'python scripts/create_initial_scenarios.py' to create initial scenarios")
            return

        try:
            with open(self.data_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            print(f"❌ Error loading scenarios: {e}")
            return

        scenarios_list = data.get("scenarios", []) if isinstance(data, dict) else None
        if not isinstance(scenarios_list, list):
            print(f"❌ Error loading scenarios: no 'scenarios' list found in {self.data_file}")
            return

        for position, scenario_data in enumerate(scenarios_list):
            try:
                scenario = ADASScenario.from_dict(scenario_data)
                # A later entry with the same scenario_id replaces an earlier one.
                self.scenarios[scenario.scenario_id] = scenario
            except Exception as e:  # from_dict's failure modes are not visible here
                print(f"❌ Skipping malformed scenario #{position}: {e}")

        print(f"✅ Loaded {len(self.scenarios)} scenarios from {self.data_file}")

    def _build_index(self):
        """(Re)build the secondary indexes used by the filter methods.

        Each index maps a key (feature, type, source or tag) to the list of
        scenario IDs carrying that key, in load order.
        """
        self.index = {
            'by_feature': defaultdict(list),
            'by_type': defaultdict(list),
            'by_source': defaultdict(list),
            'by_tag': defaultdict(list),
        }

        for scenario in self.scenarios.values():
            sid = scenario.scenario_id
            self.index['by_feature'][scenario.adas_feature].append(sid)
            self.index['by_type'][scenario.scenario_type].append(sid)
            # NOTE(review): assumes `source` is a collection of source names;
            # a plain string here would be indexed character by character.
            for source in scenario.source:
                self.index['by_source'][source].append(sid)
            for tag in scenario.tags:
                self.index['by_tag'][tag].append(sid)

    def get_by_id(self, scenario_id: str) -> Optional[ADASScenario]:
        """Return the scenario with ``scenario_id``, or ``None`` if absent."""
        return self.scenarios.get(scenario_id)

    def get_all(self) -> List[ADASScenario]:
        """Return all loaded scenarios as a new list, in load order."""
        return list(self.scenarios.values())

    def filter_by_features(self, features: List[str]) -> List[ADASScenario]:
        """Return scenarios whose ADAS feature is any of ``features``.

        Results are de-duplicated and ordered deterministically: by the order
        of ``features``, then by load order within each feature. (Collecting
        IDs in a ``set`` would make the order depend on string-hash
        randomization and vary between runs.)

        Args:
            features: Feature names to match.

        Returns:
            Matching scenarios; ``[]`` when ``features`` is empty.
        """
        if not features:
            return []

        by_feature = self.index['by_feature']
        # dict.fromkeys de-duplicates while preserving first-seen order;
        # .get() avoids inserting empty entries into the defaultdict.
        scenario_ids = dict.fromkeys(
            sid
            for feature in features
            for sid in by_feature.get(feature, ())
        )
        return [self.scenarios[sid] for sid in scenario_ids if sid in self.scenarios]

    def filter_by_type(self, scenario_type: str) -> List[ADASScenario]:
        """Return scenarios whose ``scenario_type`` equals ``scenario_type``, in load order."""
        scenario_ids = self.index['by_type'].get(scenario_type, [])
        return [self.scenarios[sid] for sid in scenario_ids if sid in self.scenarios]

    def full_text_search(self, query: str, top_k: int = 10) -> List[ADASScenario]:
        """Rank scenarios by simple keyword overlap with ``query``.

        The score is the fraction of distinct query words found in the
        scenario's title, description or tags. Matching is case-insensitive
        and ignores punctuation. Scenarios matching no word are excluded, and
        ties keep load order.

        Args:
            query: Free-text query.
            top_k: Maximum number of results; ``top_k <= 0`` returns ``[]``.

        Returns:
            Up to ``top_k`` scenarios, best match first.
        """
        query_words = self._tokenize(query)
        if not query_words or top_k <= 0:
            # A negative top_k would otherwise slice off the tail of the
            # ranking instead of limiting the result count.
            return []

        scored_scenarios = []
        for scenario in self.scenarios.values():
            text = f"{scenario.title} {scenario.description} {' '.join(scenario.tags)}"
            matches = len(query_words & self._tokenize(text))
            if matches:
                scored_scenarios.append((matches / len(query_words), scenario))

        # list.sort is stable, so equal scores keep load order.
        scored_scenarios.sort(key=lambda item: item[0], reverse=True)
        return [scenario for _, scenario in scored_scenarios[:top_k]]

    @staticmethod
    def _tokenize(text: str) -> Set[str]:
        """Return the set of lower-cased word tokens in ``text``."""
        return set(_WORD_RE.findall(text.lower()))