import os
from typing import List

from qdrant_client import QdrantClient
from qdrant_client.models import NamedVector

from app.services.embeddings_service import EmbeddingsService, GeminiEmbeddingsService
from app.services.openai_service import GeminiService, OpenAIService


class RAGService:
    """Retrieval-augmented generation over a Qdrant vector store.

    Embeds a user query, retrieves the most similar stored text chunks from
    Qdrant, and asks the configured chat model to answer the query grounded
    in that retrieved context.
    """

    def __init__(
        self,
        qdrant_client: QdrantClient,
        embeddings_service: EmbeddingsService,
        ai_service: OpenAIService,
    ):
        """Store the injected clients/services used for retrieval and generation.

        Args:
            qdrant_client: Connected Qdrant client used for similarity search.
            embeddings_service: Service exposing ``async create_embedding(text)``
                (OpenAI- or Gemini-backed — both share the same interface).
            ai_service: Chat service exposing ``async get_chat_response(prompt)``.
        """
        self.qdrant_client = qdrant_client
        self.embeddings_service = embeddings_service
        self.ai_service = ai_service
        # Collection name is configurable via environment; the default must
        # match the collection name used at ingestion time.
        self.collection_name = os.getenv("QDRANT_COLLECTION_NAME", "book_embeddings")

    async def retrieve_context(self, query: str, top_k: int = 3) -> List[str]:
        """Return the ``top_k`` stored chunks most similar to ``query``.

        Args:
            query: Natural-language question to embed and search with.
            top_k: Maximum number of chunks to retrieve.

        Returns:
            The ``content`` payload of each matching point (empty string when
            a payload lacks the key); points without a payload are skipped.
        """
        # Both embedding backends expose the same async create_embedding()
        # method, so no isinstance() branching is needed here — the original
        # branches were identical.
        query_vector = await self.embeddings_service.create_embedding(query)

        search_result = self.qdrant_client.search(
            collection_name=self.collection_name,
            query_vector=query_vector,
            limit=top_k,
            with_payload=True,
        )

        return [
            point.payload.get("content", "")
            for point in search_result
            if point.payload
        ]

    async def generate_response(self, query: str, context: List[str]) -> str:
        """Answer ``query`` using the retrieved ``context`` chunks.

        Args:
            query: The user's question.
            context: Text chunks previously returned by ``retrieve_context``.

        Returns:
            The chat model's answer as plain text.
        """
        full_prompt = f"""Context: {' '.join(context)}

Question: {query}

Answer:"""

        # Both chat backends expose the same async get_chat_response()
        # method, so the call is uniform — no per-backend branching needed.
        return await self.ai_service.get_chat_response(full_prompt)