# rag-api-node-1 / src/api/dependencies.py
# Last commit daf250b by Peterase: "feat: hybrid RAG pipeline upgrade"
from fastapi import Depends
from sqlalchemy.orm import Session
from src.infrastructure.database import get_db
# Adapters
from src.infrastructure.adapters.bge_embedder_adapter import BgeEmbedderAdapter
from src.infrastructure.adapters.qdrant_adapter import QdrantAdapter
from src.infrastructure.adapters.bge_reranker_adapter import BgeRerankerAdapter
from src.infrastructure.adapters.openai_adapter import OpenAiAdapter
from src.infrastructure.adapters.ollama_adapter import OllamaAdapter
from src.infrastructure.adapters.groq_adapter import GroqAdapter
from src.infrastructure.adapters.gemini_adapter import GeminiAdapter
from src.infrastructure.adapters.together_adapter import TogetherAdapter
from src.infrastructure.adapters.huggingface_adapter import HuggingFaceAdapter
from src.infrastructure.adapters.clickhouse_adapter import ClickHouseAdapter
from src.infrastructure.adapters.postgres_adapter import PostgresAdapter
from src.infrastructure.adapters.redis_adapter import RedisAdapter
from src.infrastructure.adapters.duckduckgo_adapter import DuckDuckGoAdapter
from src.infrastructure.adapters.newsapi_adapter import NewsAPIAdapter
# Hybrid Search Components
from src.core.orchestrator.query_orchestrator import QueryOrchestrator
from src.core.ranking.hybrid_result_ranker import HybridResultRanker
# Use Cases
from src.core.use_cases.search_use_case import SearchUseCase
from src.core.use_cases.rag_chat_use_case import RagChatUseCase
from src.core.use_cases.analytics_use_case import AnalyticsUseCase
# Global singletons, built once when this module is first imported, so adapters
# (and any models they load) are not re-created on every request.
# NOTE(review): the original comment called these "stateless"; some of them may
# hold client connections or loaded models, so whether each one is safe to share
# across concurrent requests depends on the adapter implementation -- confirm.
# Construction order is left unchanged in case a constructor has side effects.
embedder_adapter = BgeEmbedderAdapter()
qdrant_adapter = QdrantAdapter()
reranker_adapter = BgeRerankerAdapter()
# One adapter per supported LLM backend; get_llm_port() returns one of these
# based on settings.LLM_PROVIDER.
openai_adapter = OpenAiAdapter()
ollama_adapter = OllamaAdapter()
groq_adapter = GroqAdapter()
gemini_adapter = GeminiAdapter()
together_adapter = TogetherAdapter()
huggingface_adapter = HuggingFaceAdapter()
# Analytics store and cache backends.
clickhouse_adapter = ClickHouseAdapter()
redis_adapter = RedisAdapter()
# --- Hybrid search singletons (live web/news search + DB search) ---
from src.core.config import settings

# Live web search adapter; always constructed, sized/timed from settings.
duckduckgo_adapter = DuckDuckGoAdapter(
    timeout=settings.LIVE_SEARCH_TIMEOUT,
    max_results=settings.LIVE_SEARCH_MAX_RESULTS,
)

# The NewsAPI adapter is optional: behind the NEWSAPI_ENABLED flag. When the
# flag is off it stays None and None is what the orchestrator receives.
if settings.NEWSAPI_ENABLED:
    newsapi_adapter = NewsAPIAdapter(
        api_key=settings.NEWSAPI_KEY,
        timeout=settings.NEWSAPI_TIMEOUT,
        max_results=settings.NEWSAPI_MAX_RESULTS,
    )
else:
    newsapi_adapter = None

# Query orchestrator, configured with the hybrid-search switch and the default
# live-vs-DB weights from settings.
query_orchestrator = QueryOrchestrator(
    live_search_adapter=duckduckgo_adapter,
    newsapi_adapter=newsapi_adapter,
    enable_hybrid=settings.ENABLE_HYBRID_SEARCH,
    default_live_weight=settings.LIVE_SEARCH_WEIGHT,
    default_db_weight=settings.DB_SEARCH_WEIGHT,
)

# Hybrid result ranker, sharing the module-level reranker singleton.
hybrid_result_ranker = HybridResultRanker(reranker=reranker_adapter)
# --- Model pre-warming ---
def prewarm_models():
    """Force the BGE embedder and reranker adapters to load their models now.

    Presumably called from application startup so the first request does not
    pay the model-load cost; no caller is visible in this module.
    """
    # Same order as before: embedder first, then reranker.
    for adapter in (embedder_adapter, reranker_adapter):
        # NOTE(review): this reaches into a private hook; a public warm-up method
        # on the adapters would be a cleaner contract.
        adapter._load_model()
# --- Dependency Providers ---
def get_embedder_port():
    """FastAPI dependency: the shared, module-level BGE embedder adapter."""
    shared_embedder = embedder_adapter
    return shared_embedder
def get_vector_store_port():
    """FastAPI dependency: the shared, module-level Qdrant vector-store adapter."""
    shared_store = qdrant_adapter
    return shared_store
def get_reranker_port():
    """FastAPI dependency: the shared, module-level BGE reranker adapter."""
    shared_reranker = reranker_adapter
    return shared_reranker
from src.core.config import settings
def get_llm_port():
    """FastAPI dependency: return the shared LLM adapter chosen by settings.

    The provider name is read from ``settings.LLM_PROVIDER``. Matching is
    case-insensitive and ignores surrounding whitespace. Recognised names:
    ``groq``, ``gemini``, ``together``, ``huggingface`` / ``hf`` and ``ollama``.

    Any other value falls back to the OpenAI adapter. This includes ``openai``
    itself, an empty string, and an unset (``None``) setting.
    """
    # Tolerate an unset value (previously an AttributeError on .lower()) and
    # stray whitespace such as a trailing space in an env var (previously a
    # silent fallback to OpenAI).
    provider = (settings.LLM_PROVIDER or "").strip().lower()
    adapters_by_provider = {
        "groq": groq_adapter,
        "gemini": gemini_adapter,
        "together": together_adapter,
        "huggingface": huggingface_adapter,
        "hf": huggingface_adapter,
        "ollama": ollama_adapter,
    }
    # NOTE(review): an unrecognised name (e.g. a typo) still silently selects
    # OpenAI; consider validating LLM_PROVIDER once at startup.
    return adapters_by_provider.get(provider, openai_adapter)
def get_analytics_db_port():
    """FastAPI dependency: the shared, module-level ClickHouse analytics adapter."""
    shared_analytics = clickhouse_adapter
    return shared_analytics
def get_chat_history_port(db: Session = Depends(get_db)):
    """FastAPI dependency: a chat-history adapter bound to this request's DB session.

    Unlike the other ports this is not a singleton. A new PostgresAdapter wraps
    whatever session ``get_db`` supplies.
    """
    history_adapter = PostgresAdapter(db)
    return history_adapter
def get_cache_port():
    """FastAPI dependency: the shared, module-level Redis cache adapter."""
    shared_cache = redis_adapter
    return shared_cache
def get_live_search_port():
    """FastAPI dependency: the shared, module-level DuckDuckGo live-search adapter."""
    shared_live_search = duckduckgo_adapter
    return shared_live_search
def get_query_orchestrator():
    """FastAPI dependency: the shared, module-level QueryOrchestrator."""
    shared_orchestrator = query_orchestrator
    return shared_orchestrator
def get_hybrid_ranker():
    """FastAPI dependency: the shared, module-level HybridResultRanker."""
    shared_ranker = hybrid_result_ranker
    return shared_ranker
# --- Use Case Providers ---
def get_search_use_case(
    embedder=Depends(get_embedder_port),
    vector_store=Depends(get_vector_store_port),
):
    """Build a SearchUseCase from the resolved embedder and vector-store ports.

    Arguments are passed positionally, in (embedder, vector_store) order.
    """
    search_use_case = SearchUseCase(embedder, vector_store)
    return search_use_case
def get_rag_chat_use_case(
    embedder=Depends(get_embedder_port),
    vector_store=Depends(get_vector_store_port),
    reranker=Depends(get_reranker_port),
    llm=Depends(get_llm_port),
    chat_history=Depends(get_chat_history_port),
    analytics_db=Depends(get_analytics_db_port),
    cache=Depends(get_cache_port),
    orchestrator=Depends(get_query_orchestrator),
    hybrid_ranker=Depends(get_hybrid_ranker),
):
    """Assemble a RagChatUseCase from every port it needs.

    A new use-case object is created per call. With the default providers, every
    collaborator is a module-level singleton except ``chat_history``, which is a
    fresh PostgresAdapter on the request's DB session.
    """
    rag_chat_use_case = RagChatUseCase(
        embedder=embedder,
        vector_store=vector_store,
        reranker=reranker,
        llm=llm,
        # The constructor names this collaborator ``chat_history_db``.
        chat_history_db=chat_history,
        analytics_db=analytics_db,
        cache=cache,
        orchestrator=orchestrator,
        hybrid_ranker=hybrid_ranker,
    )
    return rag_chat_use_case
from src.core.use_cases.account_use_case import AccountUseCase
from src.core.use_cases.agent_router_use_case import AgentRouterUseCase
def get_analytics_use_case(
    analytics_db=Depends(get_analytics_db_port),
):
    """Build an AnalyticsUseCase around the resolved analytics-DB port."""
    analytics_use_case = AnalyticsUseCase(analytics_db)
    return analytics_use_case
def get_account_use_case():
    """Build a new AccountUseCase; it takes no injected collaborators."""
    account_use_case = AccountUseCase()
    return account_use_case
def get_agent_router_use_case(
    llm=Depends(get_llm_port),
    rag_chat=Depends(get_rag_chat_use_case),
    account=Depends(get_account_use_case),
    chat_history=Depends(get_chat_history_port),
):
    """Assemble the AgentRouterUseCase from its LLM, sub-use-cases and chat history.

    ``rag_chat`` is itself built from ``get_chat_history_port``. With FastAPI's
    default per-request dependency caching, the ``chat_history`` adapter here
    should be the same instance that ``rag_chat`` received.
    """
    router_use_case = AgentRouterUseCase(
        llm=llm,
        rag_chat=rag_chat,
        account=account,
        chat_history_db=chat_history,
    )
    return router_use_case