Spaces:
Sleeping
Sleeping
"""
Scenario Retriever
Retrieves relevant scenarios using semantic search and feature filtering
"""
from typing import List, Optional, Dict

from openai import OpenAI

from ..database.scenario_database import ScenarioDatabase
from ..database.scenario_models import ADASScenario, RankedScenario
class ScenarioRetriever:
    """Scenario retrieval engine.

    Combines feature-based filtering with full-text search over a scenario
    database, scores each candidate on several weighted relevance factors,
    and returns the highest-ranked scenarios for a query.
    """

    # Relative weights of the relevance factors (sum to 1.0).
    _FEATURE_WEIGHT = 0.3   # exact ADAS-feature match
    _SEMANTIC_WEIGHT = 0.4  # keyword/text similarity
    _TYPE_WEIGHT = 0.2      # scenario-type prior
    _QUALITY_WEIGHT = 0.1   # curated quality score

    # Base score per scenario type; unknown types fall back to 0.5.
    _TYPE_SCORES = {
        "boundary_condition": 1.0,
        "historical_incident": 0.8,
        "hypothetical_edge_case": 0.9,
    }

    def __init__(
        self,
        scenario_database: ScenarioDatabase,
        scenario_vector_store_id: Optional[str] = None,
        client: Optional[OpenAI] = None
    ):
        """
        Args:
            scenario_database: Scenario database to query.
            scenario_vector_store_id: Scenario vector store ID (optional,
                only needed if semantic/vector search is used).
            client: OpenAI client (optional, only needed for vector search).
        """
        self.database = scenario_database
        self.vector_store_id = scenario_vector_store_id
        self.client = client

    def retrieve(
        self,
        query: str,
        adas_features: List[str],
        max_results: int = 3,
        user_context: Optional[Dict] = None
    ) -> List[RankedScenario]:
        """
        Retrieve relevant scenarios.

        Args:
            query: User query text.
            adas_features: List of extracted ADAS feature names.
            max_results: Maximum number of results to return.
            user_context: User context (optional, for personalization).

        Returns:
            List[RankedScenario]: Scenarios sorted by descending relevance,
            truncated to ``max_results``.
        """
        # 1. Feature filtering
        feature_filtered = self.database.filter_by_features(adas_features)

        # 2. Full-text search (only if feature filtering yields too few hits)
        if len(feature_filtered) < max_results:
            text_results = self.database.full_text_search(query, top_k=max_results * 2)
            # Merge the two candidate lists, de-duplicating by object
            # identity while preserving order. (A plain `list(set(...))`
            # would require scenario objects to be hashable and would
            # randomize the order of equal-score ties.)
            merged = {id(s): s for s in feature_filtered + text_results}
            all_candidates = list(merged.values())
        else:
            all_candidates = feature_filtered

        # 3. Relevance scoring
        scored = self._score_relevance(all_candidates, query, adas_features)

        # 4. User-context adjustment (if available)
        if user_context:
            scored = self._adjust_for_user_context(scored, user_context)

        # 5. Sort (stable, so merge order breaks ties) and return top N
        scored.sort(key=lambda x: x.relevance_score, reverse=True)
        return scored[:max_results]

    def _score_relevance(
        self,
        scenarios: List[ADASScenario],
        query: str,
        adas_features: List[str]
    ) -> List[RankedScenario]:
        """
        Score each scenario's relevance to the query.

        Weighted factors:
            1. Feature match (30%)
            2. Text similarity (40%)
            3. Scenario type weight (20%)
            4. Scenario quality score (10%)

        Returns:
            List[RankedScenario]: One ranked entry per input scenario
            (unsorted; the caller sorts).
        """
        ranked = []
        # Hoist loop-invariant query preprocessing out of the scenario loop.
        query_words = set(query.lower().split())

        for scenario in scenarios:
            score = 0.0
            match_reasons = []

            # 1. Feature match (30%): all-or-nothing on the scenario's feature.
            if scenario.adas_feature in adas_features:
                feature_match = 1.0
                match_reasons.append(f"Matches feature: {scenario.adas_feature}")
            else:
                feature_match = 0.0
            score += feature_match * self._FEATURE_WEIGHT

            # 2. Text similarity (40%): fraction of query words that appear
            # in the scenario's title/description/tags (simple keyword overlap,
            # not embeddings).
            scenario_text = f"{scenario.title} {scenario.description} {' '.join(scenario.tags)}".lower()
            scenario_words = set(scenario_text.split())
            if query_words:
                text_similarity = len(query_words & scenario_words) / len(query_words)
                # Only surface the reason when the overlap is non-trivial.
                if text_similarity > 0.1:
                    match_reasons.append(f"Text similarity: {text_similarity:.2f}")
            else:
                text_similarity = 0.0
            score += text_similarity * self._SEMANTIC_WEIGHT

            # 3. Scenario type weight (20%): prior preference per type.
            type_score = self._TYPE_SCORES.get(scenario.scenario_type, 0.5)
            score += type_score * self._TYPE_WEIGHT

            # 4. Scenario quality score (10%): curated metadata score,
            # presumably already normalized to [0, 1] — TODO confirm.
            score += scenario.metadata.quality_score * self._QUALITY_WEIGHT

            ranked.append(RankedScenario(
                scenario=scenario,
                relevance_score=score,
                match_reasons=match_reasons
            ))

        return ranked

    def _adjust_for_user_context(
        self,
        ranked_scenarios: List[RankedScenario],
        user_context: Dict
    ) -> List[RankedScenario]:
        """
        Adjust relevance scores based on user context.

        Example: if the user is a beginner, prioritize basic scenarios.

        NOTE: Personalization logic is not yet implemented; scores are
        currently returned unchanged.
        """
        return ranked_scenarios