| """Chatbot agent with RAG capabilities.""" |
|
|
| import tiktoken |
| from langchain_openai import AzureChatOpenAI |
| from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder |
| from langchain_core.output_parsers import StrOutputParser |
| from src.config.settings import settings |
| from src.middlewares.logging import get_logger |
| from langchain_core.messages import HumanMessage, AIMessage |
|
|
# Module-scoped structured logger; all chatbot events are tagged "chatbot".
logger = get_logger("chatbot")
|
|
# Shared tokenizer instance (cl100k_base, the GPT-4-family BPE) so token
# counting does not re-load the encoding tables on every request.
_enc = tiktoken.get_encoding("cl100k_base")
|
|
|
|
def _count_tokens(messages: list, context: str) -> dict:
    """Count prompt tokens for the chat history and RAG context.

    Args:
        messages: Message objects exposing a ``.content`` attribute.
        context: Retrieved document text injected into the prompt.

    Returns:
        Dict with ``messages_tokens``, ``context_tokens`` and ``total``.
    """
    # LangChain message content is not always a plain string (multimodal
    # messages carry a list of content parts); coerce to str so encoding
    # never raises and token logging cannot crash the request path.
    msg_tokens = sum(
        len(_enc.encode(m.content if isinstance(m.content, str) else str(m.content)))
        for m in messages
    )
    ctx_tokens = len(_enc.encode(context))
    return {
        "messages_tokens": msg_tokens,
        "context_tokens": ctx_tokens,
        "total": msg_tokens + ctx_tokens,
    }
|
|
|
|
class ChatbotAgent:
    """Chatbot agent with RAG capabilities.

    Wraps an AzureChatOpenAI deployment behind a prompt template that
    injects the running chat history plus any retrieved document context,
    exposing both one-shot and streaming generation.
    """

    def __init__(self):
        # GPT-4o deployment from project settings; temperature 0.7 favours
        # conversational variety over determinism.
        self.llm = AzureChatOpenAI(
            azure_deployment=settings.azureai_deployment_name_4o,
            openai_api_version=settings.azureai_api_version_4o,
            azure_endpoint=settings.azureai_endpoint_url_4o,
            api_key=settings.azureai_api_key_4o,
            temperature=0.7
        )

        # Load the system persona from disk, falling back to a generic
        # assistant prompt when the file is missing (e.g. fresh checkout).
        # Explicit UTF-8: prompt files may contain non-ASCII text and the
        # platform default encoding is not guaranteed to be UTF-8.
        try:
            with open("src/config/agents/system_prompt.md", "r", encoding="utf-8") as f:
                system_prompt = f.read()
        except FileNotFoundError:
            system_prompt = "You are a helpful AI assistant with access to user's uploaded documents."

        # Chat history sits between the persona and the retrieved-context
        # block so the documents land adjacent to the newest user turn.
        self.prompt = ChatPromptTemplate.from_messages([
            ("system", system_prompt),
            MessagesPlaceholder(variable_name="messages"),
            ("system", "Relevant documents:\n{context}")
        ])

        # prompt -> LLM -> plain string output.
        self.chain = self.prompt | self.llm | StrOutputParser()

    async def generate_response(
        self,
        messages: list,
        context: str = ""
    ) -> str:
        """Generate a single complete response with optional RAG context.

        Args:
            messages: Chat history as LangChain message objects.
            context: Retrieved document text; empty string when RAG is off.

        Returns:
            The model's reply as a plain string.

        Raises:
            Exception: Re-raises any failure from the underlying chain
                after logging it.
        """
        try:
            logger.info("Generating chatbot response")

            # Log input size up front, mirroring astream_response, so token
            # usage is observable on both code paths.
            token_counts = _count_tokens(messages, context)
            logger.info("LLM input tokens", **token_counts)

            response = await self.chain.ainvoke({
                "messages": messages,
                "context": context
            })

            logger.info(f"Generated response: {response[:100]}...")
            return response

        except Exception as e:
            logger.error("Response generation failed", error=str(e))
            raise

    async def astream_response(self, messages: list, context: str = ""):
        """Stream response tokens as they are generated.

        Args:
            messages: Chat history as LangChain message objects.
            context: Retrieved document text; empty string when RAG is off.

        Yields:
            str: Incremental response chunks from the model.

        Raises:
            Exception: Re-raises any failure from the underlying chain
                after logging it.
        """
        try:
            token_counts = _count_tokens(messages, context)
            logger.info("LLM input tokens", **token_counts)
            async for token in self.chain.astream({"messages": messages, "context": context}):
                yield token
        except Exception as e:
            logger.error("Response streaming failed", error=str(e))
            raise
|
|
|
|
# Module-level singleton: importing this module constructs the agent (and
# its Azure LLM client) once at import time.
chatbot = ChatbotAgent()
|
|