Spaces:
Sleeping
Sleeping
| """ | |
| Scenario Database Management | |
| """ | |
import json
import os
import re
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional

from .scenario_models import ADASScenario
class ScenarioDatabase:
    """In-memory store of ADAS scenarios loaded from a JSON file.

    Scenarios are kept in ``self.scenarios`` keyed by ``scenario_id`` and are
    indexed by feature, type, source, and tag for fast filtering.
    """

    # NOTE: ``ADASScenario`` is quoted in annotations so that defining this
    # class does not require the model name to be resolvable at that moment.

    def __init__(self, data_file: str = "data/scenarios/initial_scenarios.json"):
        """Load scenarios from *data_file* and build the lookup indexes.

        Args:
            data_file: Path to the scenario JSON file. A relative path is
                resolved against the current working directory.
        """
        self.data_file = Path(data_file)
        self.scenarios: Dict[str, "ADASScenario"] = {}
        # index name ('by_feature', 'by_type', 'by_source', 'by_tag')
        #   -> index key -> list of scenario IDs
        self.index: Dict[str, Dict[str, List[str]]] = {}
        self._load_scenarios()
        self._build_index()

    def _load_scenarios(self) -> None:
        """Populate ``self.scenarios`` from the JSON file at ``self.data_file``.

        The file must contain a top-level object with a ``"scenarios"`` list;
        each entry is passed to ``ADASScenario.from_dict``. Loading is
        best-effort: a missing, unreadable, or malformed file leaves the
        database empty, and an entry that fails to parse is reported and
        skipped rather than aborting the remaining entries. All problems are
        reported with ``print``; nothing is raised.
        """
        if not self.data_file.exists():
            print(f"⚠️ Scenario database file not found: {self.data_file}")
            print(" Run 'python scripts/create_initial_scenarios.py' to create initial scenarios")
            return

        try:
            with open(self.data_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            print(f"❌ Error loading scenarios: {e}")
            return

        scenarios_list = data.get("scenarios", []) if isinstance(data, dict) else None
        if not isinstance(scenarios_list, list):
            print(
                "❌ Error loading scenarios: expected a JSON object with a "
                f"'scenarios' list in {self.data_file}"
            )
            return

        for position, scenario_data in enumerate(scenarios_list):
            try:
                scenario = ADASScenario.from_dict(scenario_data)
            except Exception as e:
                # from_dict's failure modes are not visible from here; skip
                # just this entry so one bad record doesn't drop the rest.
                print(f"❌ Skipping invalid scenario #{position}: {e}")
                continue
            self.scenarios[scenario.scenario_id] = scenario

        print(f"✅ Loaded {len(self.scenarios)} scenarios from {self.data_file}")

    def _build_index(self) -> None:
        """(Re)build the secondary indexes from ``self.scenarios``.

        Each index maps a key (feature, type, source, or tag) to the list of
        IDs of scenarios carrying that key, in ``self.scenarios`` order.
        """
        self.index = {
            'by_feature': defaultdict(list),
            'by_type': defaultdict(list),
            'by_source': defaultdict(list),
            'by_tag': defaultdict(list),
        }
        for scenario in self.scenarios.values():
            sid = scenario.scenario_id
            self.index['by_feature'][scenario.adas_feature].append(sid)
            self.index['by_type'][scenario.scenario_type].append(sid)
            for source in scenario.source:
                self.index['by_source'][source].append(sid)
            for tag in scenario.tags:
                self.index['by_tag'][tag].append(sid)

    def get_by_id(self, scenario_id: str) -> Optional["ADASScenario"]:
        """Return the scenario with *scenario_id*, or ``None`` if unknown."""
        return self.scenarios.get(scenario_id)

    def get_all(self) -> List["ADASScenario"]:
        """Return all loaded scenarios as a new list, in load order."""
        return list(self.scenarios.values())

    def filter_by_features(self, features: List[str]) -> List["ADASScenario"]:
        """Return scenarios whose ADAS feature is any of *features*.

        Each scenario appears at most once. Results are ordered by the first
        entry of *features* that matches them (and by load order within one
        feature), so the output is deterministic. Empty *features* -> ``[]``.
        """
        if not features:
            return []
        by_feature = self.index['by_feature']
        # dict.fromkeys de-duplicates while keeping first-seen order (a set
        # would make the result order depend on hash randomization);
        # .get avoids inserting empty lists into the defaultdict.
        ordered_ids = dict.fromkeys(
            sid for feature in features for sid in by_feature.get(feature, ())
        )
        return [self.scenarios[sid] for sid in ordered_ids if sid in self.scenarios]

    def filter_by_type(self, scenario_type: str) -> List["ADASScenario"]:
        """Return scenarios whose ``scenario_type`` equals *scenario_type*, in load order."""
        scenario_ids = self.index['by_type'].get(scenario_type, [])
        return [self.scenarios[sid] for sid in scenario_ids if sid in self.scenarios]

    @staticmethod
    def _tokenize(text: str) -> set:
        """Return the set of case-folded word tokens in *text*.

        Splits on every non-word character, so punctuation such as
        ``"braking:"`` or ``"(AEB)"`` does not prevent a keyword match.
        """
        return set(re.findall(r"\w+", text.casefold()))

    def full_text_search(self, query: str, top_k: int = 10) -> List["ADASScenario"]:
        """Rank scenarios by simple keyword overlap with *query*.

        A scenario's score is the fraction of distinct query words found in
        its title, description, or tags. Only scenarios with at least one
        matching word are returned, best first; ties keep load order.

        Args:
            query: Free-text search string.
            top_k: Maximum number of results; values <= 0 yield ``[]``.

        Returns:
            Up to ``top_k`` matching scenarios.
        """
        query_words = self._tokenize(query)
        if not query_words or top_k <= 0:
            return []

        scored_scenarios = []
        for scenario in self.scenarios.values():
            text = f"{scenario.title} {scenario.description} {' '.join(scenario.tags)}"
            matches = len(query_words & self._tokenize(text))
            if matches:
                scored_scenarios.append((matches / len(query_words), scenario))

        # list.sort is stable, so equal scores keep self.scenarios order.
        scored_scenarios.sort(key=lambda item: item[0], reverse=True)
        return [scenario for _, scenario in scored_scenarios[:top_k]]