# app/services/agent_service.py
import re
import uuid
import asyncio
import logging
from typing import List, Dict, Any, Optional

import httpx
from dotenv import load_dotenv

from agents import AsyncOpenAI, OpenAIChatCompletionsModel, set_tracing_disabled

from app.core.config import settings
from app.services.vector_store import vector_store

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

load_dotenv()
class RAGAgent:
    def __init__(self):
        # Use the Gemini API key from settings
        gemini_api_key = settings.GEMINI_API_KEY
        set_tracing_disabled(True)

        # Initialize the Gemini provider via its OpenAI-compatible endpoint
        self.provider = AsyncOpenAI(
            api_key=gemini_api_key,
            base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
        )
        # Chat-completions model wrapper for the Agents SDK (not used by
        # process_query below, which calls the provider directly)
        self.model = OpenAIChatCompletionsModel(
            model=settings.LLM_MODEL,  # Use the model configured in settings
            openai_client=self.provider,
        )

        # Async client for the OpenAI-compatible embeddings endpoint; kept for
        # the commented-out generate_embedding variant below. The active
        # implementation calls Gemini's native REST API instead.
        self.embedding_client = AsyncOpenAI(
            api_key=gemini_api_key,
            base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
        )

        # Initialize the vector store (force_recreate=True drops any existing
        # collection, so documents must be re-ingested after every restart)
        vector_store.create_collection(force_recreate=True)
    # Alternative implementation kept for reference: embeddings via the
    # OpenAI-compatible endpoint instead of the native REST API used below.
    # async def generate_embedding(self, text: str) -> List[float]:
    #     """Generate an embedding for the given text using Gemini's embedding model.
    #
    #     Args:
    #         text: The text to embed.
    #
    #     Returns:
    #         The embedding vector (768 dimensions for text-embedding-004).
    #     """
    #     response = await self.embedding_client.embeddings.create(
    #         input=text,
    #         model=settings.EMBEDDING_MODEL  # Use Gemini's embedding model
    #     )
    #     embedding = response.data[0].embedding
    #     # Validate that the dimension matches settings
    #     if len(embedding) != settings.EMBEDDING_DIMENSION:
    #         raise ValueError(
    #             f"Embedding dimension mismatch! "
    #             f"Gemini '{settings.EMBEDDING_MODEL}' returned {len(embedding)} dims, "
    #             f"but settings.EMBEDDING_DIMENSION is {settings.EMBEDDING_DIMENSION}"
    #         )
    #     return embedding
    async def generate_embedding(self, text: str) -> List[float]:
        """Generate an embedding for the given text via Gemini's native REST API.

        Args:
            text: The text to embed.

        Returns:
            The embedding vector (length settings.EMBEDDING_DIMENSION).
        """
        url = f"https://generativelanguage.googleapis.com/v1beta/models/{settings.EMBEDDING_MODEL}:embedContent"
        headers = {"Content-Type": "application/json"}
        params = {"key": settings.GEMINI_API_KEY}
        body = {"content": {"parts": [{"text": text}]}}

        async with httpx.AsyncClient() as client:
            response = await client.post(url, headers=headers, params=params, json=body)
            response.raise_for_status()
            data = response.json()

        embedding = data["embedding"]["values"]
        if len(embedding) != settings.EMBEDDING_DIMENSION:
            raise ValueError(
                f"Embedding dimension mismatch: got {len(embedding)}, "
                f"expected {settings.EMBEDDING_DIMENSION}"
            )
        return embedding
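    # Expected embedContent response shape (abbreviated sketch; values are
    # illustrative):
    #   {"embedding": {"values": [0.0123, -0.0456, ...]}}
    # data["embedding"]["values"] above extracts the raw vector from it.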
    async def retrieve_relevant_documents(self, query: str, limit: int = 5) -> List[Dict[str, Any]]:
        """Retrieve relevant documents from Qdrant based on the query.

        Args:
            query: The user query.
            limit: Maximum number of documents to retrieve.

        Returns:
            List of relevant documents with their payloads and scores.
        """
        # Generate an embedding for the query using Gemini
        query_embedding = await self.generate_embedding(query)
        # Search for relevant documents
        return vector_store.search_documents(query_embedding, limit=limit)
    async def process_query(self, user_query: str, context: Optional[str] = None,
                            selected_text: Optional[str] = None) -> Dict[str, Any]:
        """Process a user query with RAG capabilities.

        Args:
            user_query: The user's question.
            context: Additional context from chat history.
            selected_text: Text selected by the user from the book.

        Returns:
            Dictionary containing the response and sources.
        """
        # Retrieve relevant documents
        relevant_docs = await self.retrieve_relevant_documents(user_query)
        logger.info(f"Found {len(relevant_docs)} relevant documents for query: {user_query}")

        # Construct the prompt with the retrieved context
        prompt = self._construct_prompt(user_query, relevant_docs, context, selected_text)
        logger.info(f"Prompt sent to LLM: {prompt[:200]}...")  # First 200 chars

        try:
            # Call Gemini via the OpenAI-compatible interface
            response = await self.provider.chat.completions.create(
                model=settings.LLM_MODEL,
                messages=[
                    {"role": "system", "content": (
                        "You are a helpful assistant that answers questions about "
                        "Physical AI and Humanoid Robotics. Use the provided context "
                        "to answer the user's question."
                    )},
                    {"role": "user", "content": prompt}
                ],
                temperature=0.7,
                max_tokens=1500
            )
            # Log the raw response for debugging
            logger.info(f"Raw LLM response: {response}")

            # Extract the answer from the response
            answer = response.choices[0].message.content if response.choices and response.choices[0].message else ""

            # Fall back to a canned message if no answer was generated
            if not answer:
                logger.warning("LLM returned empty response, using fallback message")
                answer = "I couldn't generate a specific answer for your query. Please try rephrasing your question."
        except Exception as e:
            logger.error(f"Error calling LLM: {str(e)}")
            answer = f"Sorry, I encountered an error while processing your query: {str(e)}"

        # Format sources
        sources = []
        for doc in relevant_docs:
            if doc.get("payload"):
                sources.append({
                    "title": doc["payload"].get("title", ""),
                    "chapter": doc["payload"].get("chapter", ""),
                    "section": doc["payload"].get("section", ""),
                    "score": doc.get("score", 0)
                })

        return {
            "response": answer,
            "sources": sources
        }
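    # The dict returned by process_query has this shape (illustrative sketch):
    #   {"response": "<answer text>",
    #    "sources": [{"title": "...", "chapter": "...", "section": "...", "score": 0.87}, ...]}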
    def _construct_prompt(self, query: str, documents: List[Dict[str, Any]],
                          context: Optional[str], selected_text: Optional[str]) -> str:
        """Construct the prompt for the LLM with retrieved context.

        Args:
            query: The user's question.
            documents: Retrieved documents.
            context: Additional context from chat history.
            selected_text: Text selected by the user.

        Returns:
            Constructed prompt string.
        """
        prompt = f"Question: {query}\n\n"
        if selected_text:
            prompt += f"Selected Text Context: {selected_text}\n\n"
        if context:
            prompt += f"Conversation Context: {context}\n\n"

        prompt += "Relevant Information:\n"
        if documents:
            for i, doc in enumerate(documents, 1):
                if doc.get("payload"):
                    payload = doc["payload"]
                    prompt += f"{i}. Title: {payload.get('title', '')}\n"
                    if payload.get('chapter'):
                        prompt += f"   Chapter: {payload.get('chapter')}\n"
                    if payload.get('section'):
                        prompt += f"   Section: {payload.get('section')}\n"
                    prompt += f"   Content: {payload.get('content', '')}\n\n"
        else:
            prompt += "No relevant information found.\n\n"

        prompt += "Please provide a comprehensive answer to the question using the provided information. "
        prompt += "Cite relevant sections when appropriate and explain concepts clearly. "
        prompt += "If the provided information is not relevant or sufficient to answer the question, please state that."
        return prompt
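    # Resulting prompt layout (illustrative sketch):
    #   Question: <query>
    #   Selected Text Context: <selection>    (only if provided)
    #   Conversation Context: <history>       (only if provided)
    #   Relevant Information:
    #   1. Title / Chapter / Section / Content for each retrieved document
    #   <answering instructions>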
class RAGPipeline:
    """RAG pipeline for document ingestion and query processing."""

    def __init__(self):
        # Resolved at call time: rag_agent is the module-level singleton
        # instantiated at the bottom of this file.
        self.rag_agent = rag_agent

    def preprocess_query(self, query: str) -> str:
        """Preprocess the user query by cleaning and normalizing it.

        Args:
            query: The raw user query.

        Returns:
            Cleaned and normalized query.
        """
        # Collapse runs of whitespace
        query = re.sub(r'\s+', ' ', query.strip())
        # Remove special characters but keep basic punctuation
        query = re.sub(r'[^\w\s\.\,\!\?\;\:]', '', query)
        return query
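    # Example (illustrative): preprocess_query("  What   is  ROS2??  #intro ")
    # -> "What is ROS2?? intro"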
    def extract_keywords(self, query: str) -> List[str]:
        """Extract keywords from the query for better retrieval.

        Note: not currently called anywhere in this module.

        Args:
            query: The user query.

        Returns:
            List of extracted keywords.
        """
        # Simple stop-word-based keyword extraction
        stop_words = {
            'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for',
            'of', 'with', 'by', 'is', 'are', 'was', 'were', 'be', 'been', 'have',
            'has', 'had', 'do', 'does', 'did', 'will', 'would', 'could', 'should',
            'may', 'might', 'must', 'can'
        }
        # Tokenize and remove stop words
        words = query.lower().split()
        return [word for word in words if word not in stop_words]
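    # Example (illustrative): extract_keywords("what are the sensors in a humanoid robot")
    # -> ['what', 'sensors', 'humanoid', 'robot']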
    async def generate_response(self, query: str, context: Optional[str] = None,
                                selected_text: Optional[str] = None) -> Dict[str, Any]:
        """Generate a response using the complete RAG pipeline.

        Args:
            query: The user's question.
            context: Additional context from chat history.
            selected_text: Text selected by the user from the book.

        Returns:
            Dictionary containing the response and sources.
        """
        # Preprocess the query, then hand it to the RAG agent
        cleaned_query = self.preprocess_query(query)
        return await self.rag_agent.process_query(cleaned_query, context, selected_text)
    async def ingest_document(self, document: Dict[str, Any]) -> str:
        """Ingest a single document into the vector store.

        Args:
            document: Dictionary containing document data.

        Returns:
            ID of the ingested document.
        """
        # Generate an embedding for the document content
        embedding = await self.rag_agent.generate_embedding(document["content"])

        # Use the provided ID, or mint a new one
        doc_id = document.get("id") or str(uuid.uuid4())

        payload = {
            "document_id": doc_id,
            "title": document["title"],
            "content": document["content"],
            "chapter": document.get("chapter", ""),
            "section": document.get("section", ""),
            "subsection": document.get("subsection", "")
        }

        # Add to the vector store
        documents = [{
            "id": doc_id,
            "vector": embedding,
            "payload": payload
        }]
        added_ids = vector_store.add_documents(documents)
        return added_ids[0] if added_ids else ""
    async def ingest_documents_batch(self, documents: List[Dict[str, Any]]) -> List[str]:
        """Ingest multiple documents into the vector store efficiently.

        Args:
            documents: List of document dictionaries.

        Returns:
            List of ingested document IDs.
        """
        # Generate embeddings concurrently for better throughput
        embedding_tasks = [
            self.rag_agent.generate_embedding(doc["content"])
            for doc in documents
        ]
        embeddings = await asyncio.gather(*embedding_tasks)

        # Prepare documents for storage
        docs_to_add = []
        for doc, embedding in zip(documents, embeddings):
            doc_id = doc.get("id") or str(uuid.uuid4())
            docs_to_add.append({
                "id": doc_id,
                "vector": embedding,
                "payload": {
                    "document_id": doc_id,
                    "title": doc["title"],
                    "content": doc["content"],
                    "chapter": doc.get("chapter", ""),
                    "section": doc.get("section", ""),
                    "subsection": doc.get("subsection", "")
                }
            })

        # Add all documents to the vector store in one call
        return vector_store.add_documents(docs_to_add)
# Module-level singletons used by the rest of the app
rag_agent = RAGAgent()
rag_pipeline = RAGPipeline()
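# Minimal usage sketch (illustrative; assumes a running Qdrant instance and a
# valid GEMINI_API_KEY, with document fields mirroring the payload schema above):
#
#   import asyncio
#
#   async def _demo():
#       await rag_pipeline.ingest_document({
#           "title": "Introduction to Physical AI",
#           "content": "Physical AI combines robotics with embodied intelligence...",
#           "chapter": "1",
#       })
#       result = await rag_pipeline.generate_response("What is Physical AI?")
#       print(result["response"], result["sources"])
#
#   asyncio.run(_demo())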