Bryceeee's picture
Upload 34 files
6a11527 verified
raw
history blame
3.1 kB
"""
Enhanced RAG Engine

Wraps a standard RAG query engine and, when ADAS features are detected in a
query, augments its answer with related scenarios formatted for the UI.
"""
import logging
from typing import Tuple, Optional

from openai import OpenAI

from src.rag_query import RAGQueryEngine
from .feature_extractor import ADASFeatureExtractor
from ..retrieval.scenario_retriever import ScenarioRetriever
from ..formatting.constructive_formatter import ConstructiveFormatter
class EnhancedAnswer:
    """Result of an enhanced RAG query: the standard answer plus scenario context.

    Attributes:
        answer: Answer text produced by the base RAG engine.
        sources: Source listing returned alongside the answer.
        scenarios_html: UI-formatted scenario content (HTML per the name), or
            None when no scenarios were attached.
        scenario_count: Number of scenarios represented in ``scenarios_html``;
            0 when none were attached.
    """

    def __init__(
        self,
        answer: str,
        sources: str,
        scenarios_html: Optional[str] = None,
        scenario_count: int = 0,
    ):
        # Core payload from the standard RAG pipeline.
        self.answer, self.sources = answer, sources
        # Optional scenario enrichment; the defaults mean "no scenarios".
        self.scenarios_html, self.scenario_count = scenarios_html, scenario_count
class EnhancedRAGEngine:
    """
    Enhanced RAG engine with scenario contextualization integration.

    Every query is first answered by the base RAG engine. If the feature
    extractor finds ADAS features in the query, related scenarios are
    retrieved, ranked and formatted for the UI. Scenario enrichment is
    best-effort: any failure in it is logged and the standard answer is
    still returned without scenarios.
    """

    _logger = logging.getLogger(__name__)

    def __init__(
        self,
        base_rag_engine: RAGQueryEngine,
        scenario_retriever: ScenarioRetriever,
        feature_extractor: ADASFeatureExtractor,
        formatter: ConstructiveFormatter,
        max_scenarios: int = 3,
    ):
        """
        Args:
            base_rag_engine: Engine whose ``query(query)`` returns an
                ``(answer, sources)`` pair.
            scenario_retriever: Retrieves ranked scenarios for detected features.
            feature_extractor: Extracts ADAS features from the query text.
            formatter: Renders ranked scenarios for the UI.
            max_scenarios: Passed to the retriever as ``max_results``.
                Defaults to 3, the value previously hard-coded in ``query``.
        """
        self.base_rag = base_rag_engine
        self.scenario_retriever = scenario_retriever
        self.feature_extractor = feature_extractor
        self.formatter = formatter
        self.max_scenarios = max_scenarios

    def query(
        self,
        query: str,
        user_id: Optional[str] = None,
        user_context: Optional[dict] = None,
    ) -> EnhancedAnswer:
        """
        Execute an enhanced query.

        Args:
            query: User query.
            user_id: Accepted for interface compatibility; currently unused by
                this method. NOTE(review): personalization is not wired in yet.
            user_context: Optional context forwarded to the scenario retriever.

        Returns:
            EnhancedAnswer: The standard answer and sources, plus formatted
            scenarios and their count (``None`` / ``0`` when no scenarios apply
            or scenario enrichment failed).

        Raises:
            Whatever ``base_rag_engine.query`` raises. Without a base answer
            there is nothing to return, so those errors are not suppressed.
        """
        # 1. Standard RAG query.
        base_answer, sources = self.base_rag.query(query)

        # 2-4. Best-effort scenario enrichment; never blocks the standard answer.
        scenarios_html, scenario_count = self._retrieve_scenarios(query, user_context)

        return EnhancedAnswer(
            answer=base_answer,
            sources=sources,
            scenarios_html=scenarios_html,
            scenario_count=scenario_count,
        )

    def _retrieve_scenarios(
        self,
        query: str,
        user_context: Optional[dict],
    ) -> Tuple[Optional[str], int]:
        """Return ``(scenarios_html, scenario_count)`` for *query*, or ``(None, 0)``
        when no ADAS features or scenarios are found, or when any step fails."""
        try:
            # Feature extraction is part of the optional enrichment path, so
            # its failures must not discard the already-computed base answer.
            adas_features = self.feature_extractor.extract(query)
            if not adas_features:
                return None, 0

            ranked_scenarios = self.scenario_retriever.retrieve(
                query=query,
                adas_features=adas_features,
                max_results=self.max_scenarios,
                user_context=user_context,
            )
            if not ranked_scenarios:
                return None, 0

            scenarios_html = self.formatter.format_scenarios_for_ui(ranked_scenarios)
            return scenarios_html, len(ranked_scenarios)
        except Exception as exc:
            # Deliberate best-effort boundary. Log with traceback instead of
            # printing, so the failure is diagnosable without crashing the query.
            self._logger.warning("Error retrieving scenarios: %s", exc, exc_info=True)
            return None, 0