# chatbot/app/services/rag_service.py
# Author: Tahasaif3 — commit a0c847a ("code")
import os
from qdrant_client import QdrantClient
from qdrant_client.models import NamedVector
from typing import List
from app.services.openai_service import OpenAIService, GeminiService
from app.services.embeddings_service import EmbeddingsService, GeminiEmbeddingsService
class RAGService:
    """Retrieval-augmented generation backed by a Qdrant collection.

    Embeds a user query, fetches the closest stored text chunks from Qdrant,
    and asks an AI chat service to answer the query using those chunks as
    context.
    """

    def __init__(self, qdrant_client: QdrantClient, embeddings_service: EmbeddingsService, ai_service: OpenAIService):
        """Store collaborators and resolve the Qdrant collection name.

        Args:
            qdrant_client: Client used for the vector search.
            embeddings_service: Any object with an async
                ``create_embedding(text)`` method (OpenAI- or Gemini-backed).
            ai_service: Any object with an async ``get_chat_response(prompt)``
                method (``OpenAIService`` or ``GeminiService``).
        """
        self.qdrant_client = qdrant_client
        self.embeddings_service = embeddings_service
        self.ai_service = ai_service
        # Configurable per deployment; falls back to "book_embeddings".
        self.collection_name = os.getenv("QDRANT_COLLECTION_NAME", "book_embeddings")

    async def retrieve_context(self, query: str, top_k: int = 3) -> List[str]:
        """Return the text of up to ``top_k`` stored chunks matching ``query``.

        Args:
            query: Free-text user query to embed and search with.
            top_k: Maximum number of hits requested from Qdrant.

        Returns:
            The ``"content"`` payload field of each hit, in the order Qdrant
            returned them. Hits without a payload are skipped. Hits whose
            payload lacks ``"content"`` contribute an empty string.
        """
        # OpenAI- and Gemini-backed embedding services share the same async
        # interface, so no provider-specific branching is needed.
        query_vector = await self.embeddings_service.create_embedding(query)
        # NOTE(review): this search call is synchronous and is not awaited, so
        # it blocks the event loop while it runs. Consider AsyncQdrantClient or
        # offloading to a worker thread if this sits on a hot async path.
        hits = self.qdrant_client.search(
            collection_name=self.collection_name,
            query_vector=query_vector,
            limit=top_k,
            with_payload=True,
        )
        return [hit.payload.get("content", "") for hit in hits if hit.payload]

    async def generate_response(self, query: str, context: List[str]) -> str:
        """Ask the AI service to answer ``query`` grounded in ``context``.

        Args:
            query: The user's question.
            context: Text chunks to include in the prompt, e.g. the result of
                ``retrieve_context``.

        Returns:
            Whatever the AI service's ``get_chat_response`` returns for the
            assembled prompt.
        """
        # Separate chunks with blank lines so chunk boundaries stay visible to
        # the model (a single-space join ran adjacent chunks together). Also
        # drop empty chunks, which would only add whitespace.
        context_text = "\n\n".join(chunk for chunk in context if chunk)
        # Built with explicit newlines so the prompt does not depend on how
        # this source file happens to be indented.
        full_prompt = f"Context:\n{context_text}\n\nQuestion: {query}\nAnswer:"
        # OpenAI- and Gemini-backed services expose the same coroutine.
        return await self.ai_service.get_chat_response(full_prompt)