Spaces:
Sleeping
Sleeping
File size: 3,099 Bytes
6a11527 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 |
"""
Enhanced RAG Engine
Integrates scenario contextualization into RAG queries
"""
import logging
from typing import Tuple, Optional

from openai import OpenAI

from src.rag_query import RAGQueryEngine
from .feature_extractor import ADASFeatureExtractor
from ..retrieval.scenario_retriever import ScenarioRetriever
from ..formatting.constructive_formatter import ConstructiveFormatter
class EnhancedAnswer:
    """Result of an enhanced query: a standard answer plus optional scenario context.

    Attributes:
        answer: Answer text from the underlying RAG engine.
        sources: Source information returned alongside the answer.
        scenarios_html: Formatted scenario markup, or ``None`` when no
            scenario context is attached.
        scenario_count: Number of scenarios attached (0 when none).
    """

    def __init__(
        self,
        answer: str,
        sources: str,
        scenarios_html: Optional[str] = None,
        scenario_count: int = 0,
    ):
        # Core RAG output.
        self.answer, self.sources = answer, sources
        # Optional scenario enrichment; empty by default.
        self.scenarios_html, self.scenario_count = scenarios_html, scenario_count
class EnhancedRAGEngine:
    """
    Enhanced RAG engine with scenario contextualization integration.

    Wraps a base ``RAGQueryEngine``. When ADAS features are detected in a
    query, it retrieves related scenarios and attaches their formatted markup
    to the answer. Scenario enrichment is best-effort: a failure in feature
    extraction, scenario retrieval or formatting is logged, and the standard
    answer is still returned without scenario context.
    """

    def __init__(
        self,
        base_rag_engine: RAGQueryEngine,
        scenario_retriever: ScenarioRetriever,
        feature_extractor: ADASFeatureExtractor,
        formatter: ConstructiveFormatter
    ):
        """
        Args:
            base_rag_engine: Engine producing the standard ``(answer, sources)``.
            scenario_retriever: Retrieves scenarios relevant to a query.
            feature_extractor: Extracts ADAS features from the query text.
            formatter: Renders retrieved scenarios for the UI.
        """
        self.base_rag = base_rag_engine
        self.scenario_retriever = scenario_retriever
        self.feature_extractor = feature_extractor
        self.formatter = formatter

    def query(
        self,
        query: str,
        user_id: Optional[str] = None,
        user_context: Optional[dict] = None,
        *,
        max_scenarios: int = 3,
    ) -> EnhancedAnswer:
        """
        Execute enhanced query.

        Args:
            query: User query.
            user_id: User ID. Accepted for API compatibility / future
                personalization; currently NOT used by this method.
            user_context: Optional user context, forwarded to the scenario
                retriever.
            max_scenarios: Maximum number of scenarios to request from the
                retriever (keyword-only, default 3).

        Returns:
            EnhancedAnswer: The standard answer and sources. When scenarios
            were found, it also carries their formatted markup and count;
            otherwise these are ``None`` and 0.

        Raises:
            Any exception raised by the base RAG engine's ``query``. The
            standard answer is mandatory, so its failures are not suppressed.
        """
        # 1. Standard RAG query: must succeed, so errors propagate.
        base_answer, sources = self.base_rag.query(query)

        # 2-4. Best-effort scenario enrichment; never affects the answer above.
        scenarios_html, scenario_count = self._build_scenario_context(
            query, user_context, max_scenarios
        )

        return EnhancedAnswer(
            answer=base_answer,
            sources=sources,
            scenarios_html=scenarios_html,
            scenario_count=scenario_count
        )

    def _build_scenario_context(
        self,
        query: str,
        user_context: Optional[dict],
        max_results: int,
    ) -> Tuple[Optional[str], int]:
        """Return ``(scenarios_html, scenario_count)`` for *query*.

        Returns ``(None, 0)`` in three cases: no ADAS features are extracted,
        no scenarios are retrieved, or any step raises. Exceptions are logged
        with their traceback instead of propagating.
        """
        try:
            # 2. Extract ADAS features; without any, skip scenario lookup.
            adas_features = self.feature_extractor.extract(query)
            if not adas_features:
                return None, 0

            # 3. Retrieve relevant scenarios.
            ranked_scenarios = self.scenario_retriever.retrieve(
                query=query,
                adas_features=adas_features,
                max_results=max_results,
                user_context=user_context
            )
            if not ranked_scenarios:
                return None, 0

            # 4. Format scenarios. Both values are kept in locals and returned
            # together, so a failure here never yields a half-populated pair.
            scenarios_html = self.formatter.format_scenarios_for_ui(ranked_scenarios)
            scenario_count = len(ranked_scenarios)
        except Exception as exc:
            # Deliberate best-effort boundary: log (with traceback) and fall
            # back to an answer without scenario context.
            logging.getLogger(__name__).warning(
                "Error retrieving scenarios: %s", exc, exc_info=True
            )
            return None, 0
        return scenarios_html, scenario_count
|