File size: 5,463 Bytes
6a11527
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
"""

Scenario Retriever

Retrieves relevant scenarios using semantic search and feature filtering

"""
from typing import List, Optional, Dict
from openai import OpenAI

from ..database.scenario_database import ScenarioDatabase
from ..database.scenario_models import ADASScenario, RankedScenario


class ScenarioRetriever:
    """Scenario retrieval engine.

    Collects candidate scenarios from the scenario database (feature filter,
    topped up by full-text search when needed), scores each candidate with a
    weighted relevance formula, and returns the highest-scoring ones.
    """

    # Contribution of each factor to the final relevance score.
    _FEATURE_WEIGHT = 0.3
    _TEXT_WEIGHT = 0.4
    _TYPE_WEIGHT = 0.2
    _QUALITY_WEIGHT = 0.1

    # Score per scenario type; types not listed here get _DEFAULT_TYPE_SCORE.
    _TYPE_SCORES = {
        "boundary_condition": 1.0,
        "historical_incident": 0.8,
        "hypothetical_edge_case": 0.9,
    }
    _DEFAULT_TYPE_SCORE = 0.5

    # Text similarity must exceed this to be listed among the match reasons.
    _MIN_REPORTED_SIMILARITY = 0.1

    def __init__(
        self,
        scenario_database: ScenarioDatabase,
        scenario_vector_store_id: Optional[str] = None,
        client: Optional[OpenAI] = None
    ):
        """
        Args:
            scenario_database: Scenario database
            scenario_vector_store_id: Scenario vector store ID (optional, if using semantic search)
            client: OpenAI client (if using vector search)
        """
        self.database = scenario_database
        # Stored for vector search; not read by any method of this class.
        self.vector_store_id = scenario_vector_store_id
        self.client = client

    def retrieve(
        self,
        query: str,
        adas_features: List[str],
        max_results: int = 3,
        user_context: Optional[Dict] = None
    ) -> List[RankedScenario]:
        """
        Retrieve relevant scenarios

        Args:
            query: User query
            adas_features: List of extracted ADAS features
            max_results: Maximum number of results to return; a value of
                zero or less yields an empty list
            user_context: User context (optional); when truthy it is passed
                to ``_adjust_for_user_context`` together with the scored list

        Returns:
            List[RankedScenario]: Scenarios sorted by descending relevance
            score, at most ``max_results`` of them
        """
        # A negative slice bound would return "all but the last k" results,
        # which is never what a caller asking for <= 0 results wants.
        if max_results <= 0:
            return []

        # 1. Feature filtering
        feature_filtered = self.database.filter_by_features(adas_features)

        # 2. Full-text search (if feature filtering results are insufficient)
        if len(feature_filtered) < max_results:
            text_results = self.database.full_text_search(query, top_k=max_results * 2)
            # Merge without a set: scenarios may be unhashable, and keeping
            # the original order makes ties between equal scores reproducible.
            all_candidates = self._merge_unique(feature_filtered, text_results)
        else:
            all_candidates = feature_filtered

        # 3. Relevance scoring
        scored = self._score_relevance(all_candidates, query, adas_features)

        # 4. User context adjustment (if available)
        if user_context:
            scored = self._adjust_for_user_context(scored, user_context)

        # 5. Sort (stable, so ties keep candidate order) and return top N
        ordered = sorted(scored, key=lambda item: item.relevance_score, reverse=True)
        return ordered[:max_results]

    @staticmethod
    def _merge_unique(primary, secondary):
        """Concatenate two iterables, dropping duplicates and keeping first-seen order.

        Hashable items are de-duplicated through a set; items that cannot be
        hashed fall back to an equality scan of the items kept so far.
        """
        merged = []
        seen = set()
        for item in [*primary, *secondary]:
            try:
                is_duplicate = item in seen
            except TypeError:
                # Unhashable item: compare by equality against what we kept.
                is_duplicate = item in merged
            else:
                if not is_duplicate:
                    seen.add(item)
            if not is_duplicate:
                merged.append(item)
        return merged

    @staticmethod
    def _text_similarity(query_words, text):
        """Return the fraction of ``query_words`` that occur as whitespace-separated words in ``text``.

        Returns 0.0 when ``query_words`` is empty.
        """
        if not query_words:
            return 0.0
        shared = query_words & set(text.split())
        return len(shared) / len(query_words)

    def _score_relevance(
        self,
        scenarios: List[ADASScenario],
        query: str,
        adas_features: List[str]
    ) -> List[RankedScenario]:
        """
        Calculate scenario relevance score

        Factors:
        1. Feature match (30%)
        2. Text similarity (40%) - share of query words found in the
           scenario's title, description and tags (case-insensitive)
        3. Scenario type weight (20%)
        4. Scenario quality score (10%)

        Returns:
            One RankedScenario per input scenario, in input order (unsorted)
        """
        # The query is the same for every scenario, so tokenize it once.
        query_words = set(query.lower().split())
        ranked = []

        for scenario in scenarios:
            match_reasons = []

            # 1. Feature match (30%)
            if scenario.adas_feature in adas_features:
                feature_match = 1.0
                match_reasons.append(f"Matches feature: {scenario.adas_feature}")
            else:
                feature_match = 0.0

            # 2. Text similarity (40%) - simple keyword matching
            # `or ()` keeps a scenario whose tags are None from crashing join().
            tag_text = ' '.join(scenario.tags or ())
            scenario_text = f"{scenario.title} {scenario.description} {tag_text}".lower()
            text_similarity = self._text_similarity(query_words, scenario_text)
            if text_similarity > self._MIN_REPORTED_SIMILARITY:
                match_reasons.append(f"Text similarity: {text_similarity:.2f}")

            # 3. Scenario type weight (20%)
            type_score = self._TYPE_SCORES.get(
                scenario.scenario_type, self._DEFAULT_TYPE_SCORE
            )

            # 4. Scenario quality score (10%)
            quality_score = scenario.metadata.quality_score

            score = (
                feature_match * self._FEATURE_WEIGHT
                + text_similarity * self._TEXT_WEIGHT
                + type_score * self._TYPE_WEIGHT
                + quality_score * self._QUALITY_WEIGHT
            )

            ranked.append(RankedScenario(
                scenario=scenario,
                relevance_score=score,
                match_reasons=match_reasons
            ))

        return ranked

    def _adjust_for_user_context(
        self,
        ranked_scenarios: List[RankedScenario],
        user_context: Dict
    ) -> List[RankedScenario]:
        """
        Adjust relevance scores based on user context

        Example: If user is a beginner, prioritize basic scenarios

        Currently a pass-through: the input list is returned unchanged and
        ``user_context`` is not inspected.
        """
        # Personalization adjustment logic can be implemented here
        # Currently returns as-is
        return ranked_scenarios