Bryceeee's picture
Upload 34 files
6a11527 verified
"""
Scenario Retriever
Retrieves relevant scenarios using ADAS feature filtering, a full-text search fallback, and keyword-based relevance scoring
"""
import re
from typing import List, Optional, Dict

from openai import OpenAI

from ..database.scenario_database import ScenarioDatabase
from ..database.scenario_models import ADASScenario, RankedScenario
class ScenarioRetriever:
    """Scenario retrieval engine.

    Ranks scenarios from a ``ScenarioDatabase`` by combining ADAS feature
    filtering, a full-text search fallback and a weighted relevance score
    (feature match, keyword overlap, scenario type, scenario quality).
    """

    # Relative weights of the relevance factors (they sum to 1.0).
    FEATURE_WEIGHT = 0.3
    TEXT_WEIGHT = 0.4
    TYPE_WEIGHT = 0.2
    QUALITY_WEIGHT = 0.1

    # Score per scenario type; types not listed fall back to DEFAULT_TYPE_SCORE.
    # Treated as read-only.
    TYPE_SCORES = {
        "boundary_condition": 1.0,
        "historical_incident": 0.8,
        "hypothetical_edge_case": 0.9,
    }
    DEFAULT_TYPE_SCORE = 0.5

    # Text similarity must exceed this value to be reported in match_reasons.
    TEXT_REASON_THRESHOLD = 0.1

    # Word tokens; using \w+ keeps punctuation ("acc?", "work.") from
    # preventing a keyword match.
    _WORD_RE = re.compile(r"\w+")

    def __init__(
        self,
        scenario_database: ScenarioDatabase,
        scenario_vector_store_id: Optional[str] = None,
        client: Optional[OpenAI] = None
    ):
        """
        Args:
            scenario_database: Scenario database queried by ``retrieve``.
            scenario_vector_store_id: Scenario vector store ID (optional, for
                semantic search). Stored only; not used by the retrieval
                logic in this class.
            client: OpenAI client (optional, for vector search). Stored
                only; not used by the retrieval logic in this class.
        """
        self.database = scenario_database
        self.vector_store_id = scenario_vector_store_id
        self.client = client

    def retrieve(
        self,
        query: str,
        adas_features: List[str],
        max_results: int = 3,
        user_context: Optional[Dict] = None
    ) -> List[RankedScenario]:
        """Retrieve the most relevant scenarios for a query.

        Args:
            query: User query.
            adas_features: ADAS features extracted from the query.
            max_results: Maximum number of results to return. A value <= 0
                returns an empty list without querying the database.
            user_context: User context (optional, for personalization).

        Returns:
            List[RankedScenario]: At most ``max_results`` scenarios sorted by
            descending relevance score. Ties keep candidate order, because the
            sort is stable and candidate order is deterministic.
        """
        if max_results <= 0:
            return []

        # 1. Feature filtering
        candidates = self.database.filter_by_features(adas_features)

        # 2. Top up with full-text search if feature filtering found too few
        if len(candidates) < max_results:
            text_results = self.database.full_text_search(query, top_k=max_results * 2)
            candidates = self._merge_unique(candidates, text_results)

        # 3. Relevance scoring
        scored = self._score_relevance(candidates, query, adas_features)

        # 4. User context adjustment (if available)
        if user_context:
            scored = self._adjust_for_user_context(scored, user_context)

        # 5. Sort and return top N
        scored.sort(key=lambda ranked: ranked.relevance_score, reverse=True)
        return scored[:max_results]

    @staticmethod
    def _merge_unique(
        primary: List[ADASScenario],
        secondary: List[ADASScenario]
    ) -> List[ADASScenario]:
        """Concatenate two candidate lists, dropping duplicates and keeping first-seen order.

        Uses equality instead of ``set()``. This avoids requiring scenarios
        to be hashable, and yields a reproducible order. The inputs here are
        small (bounded by a few multiples of ``max_results``), so the
        quadratic membership test is cheap.
        """
        merged: List[ADASScenario] = []
        for scenario in [*primary, *secondary]:
            if scenario not in merged:
                merged.append(scenario)
        return merged

    @classmethod
    def _tokenize(cls, text: str) -> set:
        """Lower-case ``text`` and return its set of word tokens (punctuation ignored)."""
        return set(cls._WORD_RE.findall(text.lower()))

    @classmethod
    def _text_similarity(cls, query_words: set, scenario: ADASScenario) -> float:
        """Return the fraction of query words found in the scenario's title, description and tags.

        Returns 0.0 when the query has no word tokens.
        """
        if not query_words:
            return 0.0
        scenario_text = f"{scenario.title} {scenario.description} {' '.join(scenario.tags)}"
        scenario_words = cls._tokenize(scenario_text)
        return len(query_words & scenario_words) / len(query_words)

    def _score_relevance(
        self,
        scenarios: List[ADASScenario],
        query: str,
        adas_features: List[str]
    ) -> List[RankedScenario]:
        """Calculate a relevance score for each scenario.

        Factors:
            1. Feature match (30%): 1.0 if the scenario's feature was requested.
            2. Text similarity (40%): fraction of query words in the scenario text.
            3. Scenario type weight (20%): looked up in ``TYPE_SCORES``.
            4. Scenario quality score (10%): ``scenario.metadata.quality_score``.

        Returns:
            List[RankedScenario]: One entry per input scenario, in input order,
            with human-readable ``match_reasons``.
        """
        # Tokenize the query once rather than once per scenario.
        query_words = self._tokenize(query)
        ranked = []
        for scenario in scenarios:
            match_reasons = []

            # 1. Feature match. A list membership test is kept on purpose:
            # it relies on == only, not on matching hashes.
            matches_feature = scenario.adas_feature in adas_features
            if matches_feature:
                match_reasons.append(f"Matches feature: {scenario.adas_feature}")
            feature_match = 1.0 if matches_feature else 0.0

            # 2. Text similarity (simple keyword overlap)
            text_similarity = self._text_similarity(query_words, scenario)
            if text_similarity > self.TEXT_REASON_THRESHOLD:
                match_reasons.append(f"Text similarity: {text_similarity:.2f}")

            # 3. Scenario type
            type_score = self.TYPE_SCORES.get(scenario.scenario_type, self.DEFAULT_TYPE_SCORE)

            # 4. Scenario quality
            quality_score = scenario.metadata.quality_score

            score = (
                feature_match * self.FEATURE_WEIGHT
                + text_similarity * self.TEXT_WEIGHT
                + type_score * self.TYPE_WEIGHT
                + quality_score * self.QUALITY_WEIGHT
            )

            ranked.append(RankedScenario(
                scenario=scenario,
                relevance_score=score,
                match_reasons=match_reasons
            ))
        return ranked

    def _adjust_for_user_context(
        self,
        ranked_scenarios: List[RankedScenario],
        user_context: Dict
    ) -> List[RankedScenario]:
        """Adjust relevance scores based on user context.

        Intended example: prioritize basic scenarios for beginner users.
        Not implemented yet; the input list is returned unchanged.
        """
        # TODO: implement personalization; currently a pass-through.
        return ranked_scenarios