Bryceeee's picture
Upload 34 files
6a11527 verified
"""
Scenario Database Management
"""
import json
import os
import re
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional

from .scenario_models import ADASScenario
class ScenarioDatabase:
    """In-memory store of ADAS scenarios loaded from a JSON file, with lookup indexes.

    On construction the JSON file is read (if it exists) and secondary indexes
    are built by ADAS feature, scenario type, source and tag, so the
    ``filter_*`` methods can look scenario IDs up instead of scanning all
    scenarios.
    """

    # Word tokenizer for full_text_search: runs of letters/digits/underscore.
    # Using this instead of str.split() means punctuation ("drifts,", "jam.")
    # or hyphenation ("cut-in") does not prevent a keyword match.
    _TOKEN_RE = re.compile(r"\w+")

    def __init__(self, data_file: str = "data/scenarios/initial_scenarios.json"):
        """Load scenarios from ``data_file`` and build the query indexes.

        Args:
            data_file: Path to the scenario JSON file. A relative path is
                resolved against the current working directory.
        """
        self.data_file = Path(data_file)
        # scenario_id -> scenario, in load order.
        self.scenarios: Dict[str, ADASScenario] = {}
        # Populated by _build_index(); maps index name -> {key: [scenario_id, ...]}.
        self.index = {}
        self._load_scenarios()
        self._build_index()

    def _load_scenarios(self):
        """Populate ``self.scenarios`` from the JSON file at ``self.data_file``.

        The file is read as UTF-8 JSON; entries of its top-level ``"scenarios"``
        list are converted with ``ADASScenario.from_dict`` and stored under
        their ``scenario_id`` (a later entry with the same ID replaces an
        earlier one).

        Loading is best-effort: a missing file, or any error raised while
        reading or converting, is reported on stdout instead of propagating.
        Entries converted before such an error remain loaded.
        """
        if not self.data_file.exists():
            print(f"⚠️ Scenario database file not found: {self.data_file}")
            print(" Run 'python scripts/create_initial_scenarios.py' to create initial scenarios")
            return
        try:
            with self.data_file.open('r', encoding='utf-8') as f:
                data = json.load(f)
            for scenario_data in data.get("scenarios", []):
                scenario = ADASScenario.from_dict(scenario_data)
                self.scenarios[scenario.scenario_id] = scenario
            print(f"✅ Loaded {len(self.scenarios)} scenarios from {self.data_file}")
        except Exception as e:
            # Deliberate top-level boundary: a bad data file should not make
            # construction fail, so report it and continue with what loaded.
            print(f"❌ Error loading scenarios: {e}")

    def _build_index(self):
        """(Re)build the lookup indexes from ``self.scenarios``.

        ``self.index`` gets four ``defaultdict(list)`` maps:
        ``'by_feature'`` (``adas_feature``), ``'by_type'`` (``scenario_type``),
        ``'by_source'`` (each item of ``source``) and ``'by_tag'`` (each item
        of ``tags``), each mapping a key to scenario IDs in
        ``self.scenarios`` order.
        """
        self.index = {
            'by_feature': defaultdict(list),
            'by_type': defaultdict(list),
            'by_source': defaultdict(list),
            'by_tag': defaultdict(list)
        }
        for scenario in self.scenarios.values():
            sid = scenario.scenario_id
            self.index['by_feature'][scenario.adas_feature].append(sid)
            self.index['by_type'][scenario.scenario_type].append(sid)
            for source in scenario.source:
                self.index['by_source'][source].append(sid)
            for tag in scenario.tags:
                self.index['by_tag'][tag].append(sid)

    def get_by_id(self, scenario_id: str) -> Optional["ADASScenario"]:
        """Return the scenario with ``scenario_id``, or ``None`` if unknown."""
        return self.scenarios.get(scenario_id)

    def get_all(self) -> List["ADASScenario"]:
        """Return a new list of all scenarios in load order."""
        return list(self.scenarios.values())

    def filter_by_features(self, features: List[str]) -> List["ADASScenario"]:
        """Return scenarios whose ADAS feature is any of ``features``.

        Each matching scenario appears once. Order is deterministic: grouped by
        the order of ``features``, then ``self.scenarios`` order within a
        feature. Empty (or ``None``) ``features`` yields ``[]``.
        """
        if not features:
            return []
        by_feature = self.index['by_feature']
        # dict.fromkeys de-duplicates while keeping first-seen order; a set
        # would make the result order vary between interpreter runs.
        # .get() (not []) avoids inserting unknown keys into the defaultdict.
        scenario_ids = dict.fromkeys(
            sid
            for feature in features
            for sid in by_feature.get(feature, ())
        )
        return [self.scenarios[sid] for sid in scenario_ids if sid in self.scenarios]

    def filter_by_type(self, scenario_type: str) -> List["ADASScenario"]:
        """Return scenarios whose ``scenario_type`` equals ``scenario_type``.

        Results are in ``self.scenarios`` order; an unknown type yields ``[]``.
        """
        # .get() (not []) avoids inserting unknown keys into the defaultdict.
        scenario_ids = self.index['by_type'].get(scenario_type, [])
        return [self.scenarios[sid] for sid in scenario_ids if sid in self.scenarios]

    @classmethod
    def _tokenize(cls, text: str) -> set:
        """Return the set of lower-cased word tokens found in ``text``."""
        return set(cls._TOKEN_RE.findall(text.lower()))

    def full_text_search(self, query: str, top_k: int = 10) -> List["ADASScenario"]:
        """Rank scenarios by keyword overlap with ``query``.

        Title, description and tags are lower-cased and split into word tokens
        (punctuation is ignored). A scenario's score is the fraction of the
        distinct query tokens it contains; scenarios with no matching token
        are excluded.

        Args:
            query: Free-text search string.
            top_k: Maximum number of results. ``None`` means no limit; values
                ``<= 0`` return an empty list.

        Returns:
            Up to ``top_k`` scenarios, highest score first; equal scores keep
            ``self.scenarios`` order.
        """
        query_words = self._tokenize(query)
        if not query_words or (top_k is not None and top_k <= 0):
            return []
        scored_scenarios = []
        for scenario in self.scenarios.values():
            text = f"{scenario.title} {scenario.description} {' '.join(scenario.tags)}"
            matches = len(query_words & self._tokenize(text))
            if matches:
                # Simple match rate over distinct query tokens.
                scored_scenarios.append((matches / len(query_words), scenario))
        # list.sort is stable even with reverse=True, so ties keep load order.
        scored_scenarios.sort(key=lambda item: item[0], reverse=True)
        return [scenario for _, scenario in scored_scenarios[:top_k]]