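# NOTE: the lines below are page-header residue from the code-hosting view,
# kept as comments so they do not break the module; they are not code.
# ishaq101's picture
# [NOTICKET] add total token logging (#9)
# d973099
# raw
# history blame
# 2.95 kB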
"""Chatbot agent with RAG capabilities."""
import tiktoken
from langchain_openai import AzureChatOpenAI
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.output_parsers import StrOutputParser
from src.config.settings import settings
from src.middlewares.logging import get_logger
from langchain_core.messages import HumanMessage, AIMessage
# Module-level logger obtained from the project's logging middleware under the
# name "chatbot"; it is called with keyword fields (e.g. error=...) below.
logger = get_logger("chatbot")
# Tokenizer used by _count_tokens; loaded once at import so every call reuses it.
# NOTE(review): the agent's settings are named *_4o. If the deployment is a
# GPT-4o model, its tokenizer is presumably o200k_base, in which case counts
# from cl100k_base are only approximate — confirm which encoding is intended.
_enc = tiktoken.get_encoding("cl100k_base")
def _count_tokens(messages: list, context: str) -> dict:
    """Count the input tokens of the chat messages and the RAG context.

    The counts are used for logging only, so this helper is written to never
    fail on unusual-but-valid input.

    Args:
        messages: Message objects exposing a ``content`` attribute. ``content``
            may be a string, a list of parts (plain strings or dicts carrying
            a ``"text"`` entry), or ``None``.
        context: Retrieved document text; ``None`` is treated as empty.

    Returns:
        A dict with the keys ``"messages_tokens"``, ``"context_tokens"`` and
        ``"total"`` (the sum of the two).
    """

    def _as_text(content) -> str:
        # Reduce message content to the text that can be tokenized.
        if isinstance(content, str):
            return content
        if content is None:
            return ""
        if isinstance(content, list):
            # Multi-part content: keep only the textual parts; non-text parts
            # (e.g. images) have no text to encode and are skipped.
            parts = []
            for part in content:
                if isinstance(part, str):
                    parts.append(part)
                elif isinstance(part, dict) and isinstance(part.get("text"), str):
                    parts.append(part["text"])
            return "".join(parts)
        return str(content)

    def _n_tokens(text: str) -> int:
        # By default tiktoken raises ValueError when the text contains a
        # special-token string such as "<|endoftext|>". User-supplied text can
        # contain those, so encode them as ordinary text instead of failing.
        return len(_enc.encode(text, disallowed_special=()))

    msg_tokens = sum(
        _n_tokens(_as_text(getattr(m, "content", None))) for m in messages
    )
    ctx_tokens = _n_tokens(context or "")
    return {
        "messages_tokens": msg_tokens,
        "context_tokens": ctx_tokens,
        "total": msg_tokens + ctx_tokens,
    }
class ChatbotAgent:
    """Chatbot agent with RAG capabilities.

    Builds a ``prompt | llm | StrOutputParser`` chain around an Azure OpenAI
    chat model. The prompt consists of a system prompt, the conversation
    messages, and a trailing system message carrying the retrieved documents.
    """

    def __init__(self):
        """Create the LLM client, load the system prompt, and build the chain."""
        self.llm = AzureChatOpenAI(
            azure_deployment=settings.azureai_deployment_name_4o,
            openai_api_version=settings.azureai_api_version_4o,
            azure_endpoint=settings.azureai_endpoint_url_4o,
            api_key=settings.azureai_api_key_4o,
            temperature=0.7,
        )

        # Read system prompt. The path is relative, so it is resolved against
        # the process working directory. The encoding is explicit so the
        # markdown file is decoded the same way on every platform.
        prompt_path = "src/config/agents/system_prompt.md"
        try:
            with open(prompt_path, "r", encoding="utf-8") as f:
                system_prompt = f.read()
        except FileNotFoundError:
            # Deliberate fallback, but make it visible: silently running with
            # the default prompt is hard to notice otherwise.
            logger.info(
                "System prompt file not found, using default prompt",
                path=prompt_path,
            )
            system_prompt = "You are a helpful AI assistant with access to user's uploaded documents."

        # Create prompt template.
        # NOTE(review): the system prompt text is used as a template, so
        # literal curly braces in the file would be read as template
        # variables — confirm the prompt file contains none it does not mean.
        self.prompt = ChatPromptTemplate.from_messages([
            ("system", system_prompt),
            MessagesPlaceholder(variable_name="messages"),
            ("system", "Relevant documents:\n{context}"),
        ])

        # Create chain
        self.chain = self.prompt | self.llm | StrOutputParser()

    def _log_input_tokens(self, messages: list, context: str) -> None:
        """Log the input token counts; a counting failure is logged, not raised."""
        try:
            token_counts = _count_tokens(messages, context)
        except Exception as e:
            # Token counts are diagnostics only; they must not abort a reply.
            logger.error("Token counting failed", error=str(e))
            return
        logger.info("LLM input tokens", **token_counts)

    async def generate_response(
        self,
        messages: list,
        context: str = ""
    ) -> str:
        """Generate a complete response with optional RAG context.

        Args:
            messages: Conversation messages fed to the ``messages`` placeholder.
            context: Retrieved document text inserted into the prompt.

        Returns:
            The model output as a string.

        Raises:
            Exception: Any error raised by the chain is logged and re-raised.
        """
        try:
            logger.info("Generating chatbot response")
            # Same input-token logging as the streaming path.
            self._log_input_tokens(messages, context)

            # Generate response
            response = await self.chain.ainvoke({
                "messages": messages,
                "context": context
            })

            logger.info(f"Generated response: {response[:100]}...")
            return response
        except Exception as e:
            logger.error("Response generation failed", error=str(e))
            raise

    async def astream_response(self, messages: list, context: str = ""):
        """Stream response tokens as they are generated.

        Args:
            messages: Conversation messages fed to the ``messages`` placeholder.
            context: Retrieved document text inserted into the prompt.

        Yields:
            Output chunks produced by the chain, in order.

        Raises:
            Exception: Any error raised while streaming is logged and re-raised.
        """
        # Kept outside the try block below so that a token-counting problem
        # is never reported as a streaming failure.
        self._log_input_tokens(messages, context)
        try:
            async for token in self.chain.astream({"messages": messages, "context": context}):
                yield token
        except Exception as e:
            logger.error("Response streaming failed", error=str(e))
            raise
# Module-level agent instance. It is constructed when this module is imported,
# so the LLM client setup and the system prompt read in __init__ run at import.
chatbot = ChatbotAgent()