Spaces:
Sleeping
Sleeping
File size: 5,463 Bytes
6a11527 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
"""
Scenario Retriever
Retrieves relevant scenarios using semantic search and feature filtering
"""
from typing import List, Optional, Dict
from openai import OpenAI
from ..database.scenario_database import ScenarioDatabase
from ..database.scenario_models import ADASScenario, RankedScenario
class ScenarioRetriever:
    """Scenario retrieval engine.

    Combines feature-based filtering with database full-text search,
    then ranks the merged candidates with a weighted relevance score.
    """

    def __init__(
        self,
        scenario_database: ScenarioDatabase,
        scenario_vector_store_id: Optional[str] = None,
        client: Optional[OpenAI] = None
    ):
        """
        Args:
            scenario_database: Scenario database
            scenario_vector_store_id: Scenario vector store ID (optional, if using semantic search)
            client: OpenAI client (if using vector search)
        """
        self.database = scenario_database
        # NOTE(review): vector_store_id and client are stored but not used by
        # any method in this module — presumably reserved for a future
        # semantic (vector) search path; confirm before removing.
        self.vector_store_id = scenario_vector_store_id
        self.client = client

    def retrieve(
        self,
        query: str,
        adas_features: List[str],
        max_results: int = 3,
        user_context: Optional[Dict] = None
    ) -> List[RankedScenario]:
        """
        Retrieve relevant scenarios.

        Args:
            query: User query
            adas_features: List of extracted ADAS features
            max_results: Maximum number of results to return
            user_context: User context (optional, for personalization)

        Returns:
            List[RankedScenario]: Scenarios sorted by descending relevance score
        """
        # 1. Feature filtering
        feature_filtered = self.database.filter_by_features(adas_features)

        # 2. Full-text search (if feature filtering results are insufficient)
        if len(feature_filtered) < max_results:
            text_results = self.database.full_text_search(query, top_k=max_results * 2)
            # Merge while preserving order and dropping duplicates by object
            # identity. The previous `list(set(...))` required ADASScenario to
            # be hashable (model classes with a custom __eq__ often are not)
            # and yielded a nondeterministic candidate order, which leaked
            # into the final ranking whenever relevance scores tied.
            seen_ids = set()
            all_candidates = []
            for candidate in feature_filtered + text_results:
                if id(candidate) not in seen_ids:
                    seen_ids.add(id(candidate))
                    all_candidates.append(candidate)
        else:
            all_candidates = feature_filtered

        # 3. Relevance scoring
        scored = self._score_relevance(all_candidates, query, adas_features)

        # 4. User context adjustment (if available)
        if user_context:
            scored = self._adjust_for_user_context(scored, user_context)

        # 5. Sort (descending, stable) and return top N
        scored.sort(key=lambda x: x.relevance_score, reverse=True)
        return scored[:max_results]

    def _score_relevance(
        self,
        scenarios: List[ADASScenario],
        query: str,
        adas_features: List[str]
    ) -> List[RankedScenario]:
        """
        Calculate a relevance score for each scenario.

        Factors:
            1. Feature match (30%)
            2. Text similarity (40%)
            3. Scenario type weight (20%)
            4. Scenario quality score (10%)

        Returns:
            List[RankedScenario]: One ranked entry per input scenario, with
            human-readable match reasons attached (unsorted).
        """
        ranked = []
        query_lower = query.lower()

        for scenario in scenarios:
            score = 0.0
            match_reasons = []

            # 1. Feature match (30%): binary — the scenario's single ADAS
            # feature either appears in the extracted feature list or not.
            feature_weight = 0.3
            if scenario.adas_feature in adas_features:
                feature_match = 1.0
                match_reasons.append(f"Matches feature: {scenario.adas_feature}")
            else:
                feature_match = 0.0
            score += feature_match * feature_weight

            # 2. Text similarity (40%): simple bag-of-words overlap — the
            # fraction of query words that appear in the scenario's title,
            # description, or tags. Guarded against an empty query.
            semantic_weight = 0.4
            scenario_text = f"{scenario.title} {scenario.description} {' '.join(scenario.tags)}".lower()
            query_words = set(query_lower.split())
            scenario_words = set(scenario_text.split())
            common_words = query_words & scenario_words
            if query_words:
                text_similarity = len(common_words) / len(query_words)
                # Only surface the similarity as a match reason when it is
                # non-trivial (> 0.1); the score contribution applies anyway.
                if text_similarity > 0.1:
                    match_reasons.append(f"Text similarity: {text_similarity:.2f}")
            else:
                text_similarity = 0.0
            score += text_similarity * semantic_weight

            # 3. Scenario type weight (20%): unknown types fall back to 0.5.
            type_weight = 0.2
            type_weights = {
                "boundary_condition": 1.0,
                "historical_incident": 0.8,
                "hypothetical_edge_case": 0.9
            }
            type_score = type_weights.get(scenario.scenario_type, 0.5)
            score += type_score * type_weight

            # 4. Scenario quality score (10%) — assumes
            # metadata.quality_score is already normalized to [0, 1];
            # TODO(review): confirm against ScenarioMetadata.
            quality_weight = 0.1
            quality_score = scenario.metadata.quality_score
            score += quality_score * quality_weight

            ranked.append(RankedScenario(
                scenario=scenario,
                relevance_score=score,
                match_reasons=match_reasons
            ))

        return ranked

    def _adjust_for_user_context(
        self,
        ranked_scenarios: List[RankedScenario],
        user_context: Dict
    ) -> List[RankedScenario]:
        """
        Adjust relevance scores based on user context.

        Example: If user is a beginner, prioritize basic scenarios.

        Currently a no-op placeholder: returns the input list unchanged.
        """
        # Personalization adjustment logic can be implemented here
        return ranked_scenarios
|