import os
from typing import List, Union

from qdrant_client import QdrantClient

from app.services.openai_service import OpenAIService, GeminiService
from app.services.embeddings_service import EmbeddingsService, GeminiEmbeddingsService


class RAGService:
    def __init__(
        self,
        qdrant_client: QdrantClient,
        embeddings_service: Union[EmbeddingsService, GeminiEmbeddingsService],
        ai_service: Union[OpenAIService, GeminiService],
    ):
        self.qdrant_client = qdrant_client
        self.embeddings_service = embeddings_service
        self.ai_service = ai_service
        self.collection_name = os.getenv("QDRANT_COLLECTION_NAME", "book_embeddings")

    async def retrieve_context(self, query: str, top_k: int = 3) -> List[str]:
        # Both the OpenAI- and Gemini-backed embeddings services expose the same
        # create_embedding interface, so no branching on the concrete type is needed.
        query_vector = await self.embeddings_service.create_embedding(query)

        # Fetch the top_k nearest chunks from Qdrant, payloads included.
        search_result = self.qdrant_client.search(
            collection_name=self.collection_name,
            query_vector=query_vector,
            limit=top_k,
            with_payload=True,
        )
        # Pull the stored chunk text out of each hit, skipping points without a payload.
        return [point.payload.get("content", "") for point in search_result if point.payload]

    async def generate_response(self, query: str, context: List[str]) -> str:
        full_prompt = f"""Context: {' '.join(context)}
Question: {query}
Answer:"""
        # OpenAIService and GeminiService share the same get_chat_response
        # interface, so the call is identical for either backend.
        return await self.ai_service.get_chat_response(full_prompt)
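
A minimal usage sketch of the retrieve-then-generate flow, assuming this module is imported alongside the app's service wiring. The Qdrant URL and the zero-argument service constructors below are illustrative assumptions; the real constructor signatures live elsewhere in app.services.

import asyncio

from qdrant_client import QdrantClient
from app.services.openai_service import OpenAIService
from app.services.embeddings_service import EmbeddingsService


async def main() -> None:
    # Hypothetical wiring: constructor arguments depend on how these
    # services are actually defined in the app.
    rag = RAGService(
        qdrant_client=QdrantClient(url="http://localhost:6333"),
        embeddings_service=EmbeddingsService(),
        ai_service=OpenAIService(),
    )
    question = "What does chapter 3 argue?"
    # Retrieve the top matching chunks, then ground the answer in them.
    context = await rag.retrieve_context(question, top_k=3)
    answer = await rag.generate_response(question, context)
    print(answer)


if __name__ == "__main__":
    asyncio.run(main())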