# cs-ai-sakura-dev/src/internal/agents/gpt_executor_agent.py
# lifedebugger's picture
# Deploy files from GitHub repository
# df01f5b
import logging
from datetime import datetime
from src.internal.agents.base_agents import Agent, AgentRequest
from src.internal.rag.inference.inferencer import Inferencer
from src.internal.rag.retriever.langchain_retriever import LangChainRetriever
from src.internal.rag.web_search.duckduckgo_search import DDGS
from typing import List, Dict
from langchain_core.documents import Document
import copy
from src.internal.rag.inference.inferencer_types import (
chunk_response, error_response, complete_response, meta_data_response
)
from openai import OpenAI
class GPTExecutorAgent(Agent):
    """Agent that streams an OpenAI chat completion for a question.

    Depending on the request flags, the agent first runs a web search (whose
    results are added to the retriever) and/or retrieves documents from the
    retriever, injects the resulting context into the prompt template, and
    then streams the model's answer chunk by chunk.
    """

    def __init__(
        self,
        ddgs: DDGS,
        retriever: LangChainRetriever,
        openai_client: OpenAI,
        model: str = "gpt-3.5-turbo",
        max_tokens: int = 200,
    ):
        """Store the collaborators used by the agent.

        Args:
            ddgs: Web-search client; its ``search`` is consumed with
                ``async for``.
            retriever: Document store used for both indexing and retrieval.
            openai_client: Client used to create the chat completion.
            model: Chat model name; defaults to the previously hard-coded
                value.
            max_tokens: Completion token limit; defaults to the previously
                hard-coded value.
        """
        # NOTE(review): the base Agent is initialised with None — presumably
        # its inferencer slot; confirm against base_agents.Agent.
        super().__init__(None)
        self.openai_client = openai_client
        self.retriever = retriever
        self.ddgs = ddgs
        self.model = model
        self.max_tokens = max_tokens
        # Text of the response currently (or most recently) being streamed.
        self.full_response = ""
        # Documents indexed by load_documents().
        self.file_paths = [
            "data/documents/dokumen.pdf",
        ]

    async def load_documents(self):
        """Index every file listed in ``self.file_paths``."""
        for file_path in self.file_paths:
            await self.add_doc(file_path)

    async def add_doc(self, file_path):
        """Add a single file to the retriever and return the retriever's result.

        The original went through ``self.inferencer.retriever``, but this
        agent is never given an inferencer (the base class receives ``None``),
        so the injected retriever is used directly.
        NOTE(review): assumes LangChainRetriever exposes
        ``add_document_from_file`` — confirm.
        """
        return await self.retriever.add_document_from_file(file_path)

    @staticmethod
    def _format_contexts(documents) -> str:
        """Render documents as a 1-based numbered list, one per line."""
        return "".join(
            f"{i}. {doc.page_content}\n" for i, doc in enumerate(documents, 1)
        )

    async def retrieve_context(self, query: str):
        """Retrieve documents for ``query`` and return them as numbered text."""
        retrieval_result = await self.retriever.retrieve(query=query)
        return self._format_contexts(retrieval_result.documents)

    async def web_search(self, query: str):
        """Search the web for ``query`` and index the hits in the retriever.

        Returns:
            The list of ``Document`` objects built from the search results.
        """
        search_results = [
            Document(
                page_content=result,
                # The original read the undefined ``self.prompt`` here.
                metadata={"source": "internet_search", "query": query},
            )
            async for result in self.ddgs.search(query, max_results=5)
        ]
        # Nothing to index when the search came back empty.
        if search_results:
            await self.retriever.add_documents(search_results)
        return search_results

    def format_prompt(self, question: str, context: str, prompt_template):
        """Fill ``{question}`` and ``{context}`` into a copy of the template.

        Args:
            question: Value substituted for ``{question}``.
            context: Value substituted for ``{context}``.
            prompt_template: Iterable of chat-message dicts; entries with a
                ``"content"`` key are formatted with ``str.format``.

        Returns:
            A new list of formatted messages; ``prompt_template`` itself is
            left untouched because it is deep-copied first.
        """
        formatted_prompt = []
        for message in copy.deepcopy(prompt_template):
            if "content" in message:
                message["content"] = message["content"].format(
                    question=question,
                    context=context,
                )
            formatted_prompt.append(message)
        return formatted_prompt

    async def get_result(self, req: AgentRequest):
        """Stream the answer for ``req`` as an async generator.

        Yields, in order: a ``meta_data_response`` (only when retrieval is
        enabled), one ``chunk_response`` per non-empty streamed delta, and a
        final ``complete_response``. If the completion fails, an
        ``error_response`` is yielded instead of raising.

        When ``req.chat_memory`` is a list it is extended in place with the
        formatted prompt, as in the original implementation.
        """
        start_time = datetime.now()
        # Track only the response of the current request.
        self.full_response = ""
        chat_memory = req.chat_memory if req.chat_memory is not None else []

        if req.enable_search:
            # The coroutine must be awaited or the search never runs. It is
            # done before retrieval so fresh hits are available to retrieve.
            try:
                await self.web_search(req.question)
            except Exception:
                # Web search only enriches the context; a failure here should
                # not prevent answering from what is already indexed.
                logging.exception("Web search failed for: %s", req.question)

        if req.enable_retrieval:
            retrieval_result = await self.retriever.retrieve(query=req.question)
            documents = list(retrieval_result.documents)
            contexts = self._format_contexts(documents)
            setup_time = (datetime.now() - start_time).total_seconds()
            yield meta_data_response(
                query=req.question,
                setup_time=setup_time,
                # Count documents, not characters of the formatted string.
                num_contexts=len(documents),
                enable_rerank=False,
            )
        else:
            contexts = "No Context"

        chat_memory += self.format_prompt(
            req.question, contexts, req.prompt_template
        )
        accumulated_text = ""
        try:
            logging.info(chat_memory)
            # NOTE(review): openai_client is annotated as the synchronous
            # OpenAI client, so iterating this stream blocks the event loop;
            # consider AsyncOpenAI if that matters for the caller.
            response = self.openai_client.chat.completions.create(
                model=self.model,
                messages=chat_memory,
                max_tokens=self.max_tokens,
                stream=True,
            )
            response_start = datetime.now()
            for stream_data in response:
                # Some stream events carry no choices; skip them.
                if not stream_data.choices:
                    continue
                choice = stream_data.choices[0]
                delta = choice.delta.content if choice.delta else ""
                if delta:
                    self.full_response += delta
                    accumulated_text += delta
                    yield chunk_response(
                        chunk=delta,
                        accumulated_text=accumulated_text,
                        generation_times=(
                            datetime.now() - response_start
                        ).total_seconds(),
                    )
                # Any finish reason ("stop", "length", ...) ends the stream.
                if choice.finish_reason is not None:
                    break
            # Always emit the completion marker; previously it was guarded by
            # a buffer that was cleared after every chunk and so never fired.
            yield complete_response(
                total_time=(datetime.now() - start_time).total_seconds(),
                accumulated_text=accumulated_text,
                contexts=contexts,
            )
        except Exception as e:
            logging.error("Request: %s", req)
            logging.exception(f"GPT Executor Failed: {e}")
            # str(e) + req raised TypeError and hid the real failure.
            yield error_response(
                e=str(e), meta_data=f"Question is : {req.question}"
            )