File size: 4,200 Bytes
6a11527
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
"""

Scenario Database Management

"""
import json
import os
from pathlib import Path
from typing import Dict, List, Optional
from collections import defaultdict

from .scenario_models import ADASScenario


class ScenarioDatabase:
    """In-memory store of ADAS scenarios loaded from a JSON file.

    Scenarios live in ``self.scenarios`` keyed by their ``scenario_id``.
    ``self.index`` holds four lookup tables (``by_feature``, ``by_type``,
    ``by_source`` and ``by_tag``), each mapping a value to the list of
    scenario IDs that carry it, in load order.
    """

    # Characters removed from both ends of a search token so that
    # "brakes," or "(rain)" still match the plain words.
    _EDGE_PUNCTUATION = ".,;:!?\"'()[]{}"

    def __init__(self, data_file: str = "data/scenarios/initial_scenarios.json"):
        """Load scenarios from *data_file* and build the lookup indexes.

        Args:
            data_file: Path to a JSON file whose top-level object contains
                a ``"scenarios"`` list. A missing or unreadable file leaves
                the database empty instead of raising.
        """
        self.data_file = Path(data_file)
        self.scenarios: Dict[str, "ADASScenario"] = {}
        self.index: Dict[str, Dict] = {}
        self._load_scenarios()
        self._build_index()

    @staticmethod
    def _unique_values(values) -> list:
        """Return *values* as a de-duplicated list, preserving order.

        ``None`` becomes an empty list and a bare string is treated as one
        value rather than being iterated character by character.
        """
        if values is None:
            return []
        if isinstance(values, str):
            return [values]
        return list(dict.fromkeys(values))

    @classmethod
    def _tokenize(cls, text: str) -> set:
        """Split *text* into a set of lowercase words without edge punctuation."""
        stripped = (
            word.strip(cls._EDGE_PUNCTUATION) for word in text.lower().split()
        )
        return {word for word in stripped if word}

    def _load_scenarios(self):
        """Populate ``self.scenarios`` from the JSON file (best effort).

        Problems are reported on stdout and never raised: a missing file or
        an unparsable document leaves the database as it was, and a single
        malformed record is skipped without discarding the records after it.
        """
        if not self.data_file.exists():
            print(f"⚠️ Scenario database file not found: {self.data_file}")
            print("   Run 'python scripts/create_initial_scenarios.py' to create initial scenarios")
            return

        try:
            with self.data_file.open('r', encoding='utf-8') as f:
                data = json.load(f)

            if not isinstance(data, dict):
                raise ValueError(
                    "top-level JSON value must be an object with a 'scenarios' list"
                )

            # ``or []`` also covers an explicit ``"scenarios": null``.
            for position, scenario_data in enumerate(data.get("scenarios") or []):
                try:
                    scenario = ADASScenario.from_dict(scenario_data)
                    scenario_id = scenario.scenario_id
                except Exception as e:
                    # One bad record must not drop every record after it.
                    print(f"⚠️ Skipping invalid scenario at index {position}: {e}")
                    continue
                self.scenarios[scenario_id] = scenario

            print(f"✅ Loaded {len(self.scenarios)} scenarios from {self.data_file}")
        except Exception as e:
            # Deliberately broad: loading is best effort and must not crash
            # construction of the database.
            print(f"❌ Error loading scenarios: {e}")

    def _build_index(self):
        """Rebuild the lookup tables in ``self.index`` from ``self.scenarios``.

        Each table maps a value to the IDs of the scenarios that have it.
        A scenario is listed at most once per source and per tag even when
        the scenario repeats that value.
        """
        self.index = {
            'by_feature': defaultdict(list),
            'by_type': defaultdict(list),
            'by_source': defaultdict(list),
            'by_tag': defaultdict(list),
        }

        for scenario in self.scenarios.values():
            scenario_id = scenario.scenario_id
            self.index['by_feature'][scenario.adas_feature].append(scenario_id)
            self.index['by_type'][scenario.scenario_type].append(scenario_id)
            for source in self._unique_values(scenario.source):
                self.index['by_source'][source].append(scenario_id)
            for tag in self._unique_values(scenario.tags):
                self.index['by_tag'][tag].append(scenario_id)

    def get_by_id(self, scenario_id: str) -> Optional["ADASScenario"]:
        """Return the scenario stored under *scenario_id*, or ``None``."""
        return self.scenarios.get(scenario_id)

    def get_all(self) -> List["ADASScenario"]:
        """Return every stored scenario as a new list, in load order."""
        return list(self.scenarios.values())

    def filter_by_features(self, features: List[str]) -> List["ADASScenario"]:
        """Return the scenarios whose ADAS feature is any of *features*.

        Args:
            features: Feature values to match; an empty list yields ``[]``.

        Returns:
            Matching scenarios without duplicates, ordered by the position
            of their feature in *features* and then by load order, so the
            result is the same on every run.
        """
        if not features:
            return []

        by_feature = self.index['by_feature']
        # ``.get`` avoids inserting empty entries into the defaultdict, and
        # ``dict.fromkeys`` de-duplicates while keeping first-seen order.
        ordered_ids = dict.fromkeys(
            scenario_id
            for feature in features
            for scenario_id in by_feature.get(feature, ())
        )
        return [self.scenarios[sid] for sid in ordered_ids if sid in self.scenarios]

    def filter_by_type(self, scenario_type: str) -> List["ADASScenario"]:
        """Return the scenarios whose type equals *scenario_type*, in load order."""
        matching_ids = self.index['by_type'].get(scenario_type, [])
        return [self.scenarios[sid] for sid in matching_ids if sid in self.scenarios]

    def full_text_search(self, query: str, top_k: int = 10) -> List["ADASScenario"]:
        """Rank scenarios by keyword overlap with *query*.

        The query and each scenario's title, description and tags are split
        into lowercase words with surrounding punctuation removed. A
        scenario's score is the fraction of distinct query words it
        contains; scenarios with no matching word are left out.

        Args:
            query: Free-text query. A query without any words yields ``[]``.
            top_k: Maximum number of results. Zero or a negative number
                yields ``[]``.

        Returns:
            Up to *top_k* scenarios, best score first; ties keep load order.
        """
        query_words = self._tokenize(query)
        if not query_words or (top_k is not None and top_k <= 0):
            return []

        scored_scenarios = []
        for scenario in self.scenarios.values():
            tag_text = ' '.join(self._unique_values(scenario.tags))
            text_words = self._tokenize(
                f"{scenario.title} {scenario.description} {tag_text}"
            )
            matches = len(query_words & text_words)
            if matches:
                scored_scenarios.append((matches / len(query_words), scenario))

        # list.sort is stable, also with reverse=True, so equal scores keep
        # their original relative order.
        scored_scenarios.sort(key=lambda pair: pair[0], reverse=True)
        return [scenario for _, scenario in scored_scenarios[:top_k]]