# chat-robot/app/services/rag_service.py
# (header reconstructed from file-viewer residue; original upload: "Add application file", commit 0cee4dc)
# import os
# from qdrant_client import QdrantClient
# from qdrant_client.models import NamedVector
# from typing import List
# from app.services.openai_service import OpenAIService
# from app.services.embeddings_service import EmbeddingsService
# class RAGService:
# def __init__(self, qdrant_client: QdrantClient, embeddings_service: EmbeddingsService, gemini_service: OpenAIService):
# self.qdrant_client = qdrant_client
# self.embeddings_service = embeddings_service
# self.gemini_service = gemini_service
# self.collection_name = os.getenv("QDRANT_COLLECTION_NAME", "book_embeddings")
# async def retrieve_context(self, query: str, top_k: int = 3) -> List[str]:
# query_vector = self.embeddings_service.create_embedding(query)
# search_result = self.qdrant_client.query_points(
# collection_name=self.collection_name,
# query=query_vector,
# limit=top_k,
# with_payload=True,
# ).points
# context = [point.payload.get("content", "") for point in search_result if point.payload]
# return context
# async def generate_response(self, query: str, context: List[str]) -> str:
# full_prompt = f"""Context: {' '.join(context)}
# Question: {query}
# Answer:"""
# response = await self.gemini_service.get_chat_response(full_prompt)
# return response
import os
from qdrant_client import QdrantClient
from qdrant_client.models import NamedVector
from typing import List
from app.services.openai_service import OpenAIService
from app.services.embeddings_service import EmbeddingsService
class RAGService:
    """Retrieval-augmented generation over a Qdrant vector store.

    Embeds a user query, pulls the nearest stored passages from Qdrant,
    and asks the chat model to answer using those passages as context.
    """

    def __init__(self, qdrant_client: QdrantClient, embeddings_service: EmbeddingsService, gemini_service: OpenAIService):
        # Collaborators are injected; only the collection name comes from the
        # environment (falls back to "book_embeddings").
        self.qdrant_client = qdrant_client
        self.embeddings_service = embeddings_service
        self.gemini_service = gemini_service
        self.collection_name = os.getenv("QDRANT_COLLECTION_NAME", "book_embeddings")

    async def retrieve_context(self, query: str, top_k: int = 3) -> List[str]:
        """Return the ``content`` payload of the ``top_k`` points nearest to *query*.

        Points without a payload are skipped; a point whose payload lacks a
        ``content`` key contributes an empty string.
        """
        embedded_query = self.embeddings_service.create_embedding(query)
        hits = self.qdrant_client.query_points(
            collection_name=self.collection_name,
            query=embedded_query,
            limit=top_k,
            with_payload=True,
        ).points
        return [hit.payload.get("content", "") for hit in hits if hit.payload]

    async def generate_response(self, query: str, context: List[str]) -> str:
        """Ask the chat model to answer *query* grounded in *context* passages."""
        # NOTE: continuation lines of the prompt are intentionally unindented —
        # they are part of the string literal sent to the model.
        full_prompt = f"""Context: {' '.join(context)}
Question: {query}
Answer:"""
        return await self.gemini_service.get_chat_response(full_prompt)