Spaces:
Configuration error
Configuration error
| import logging | |
| from datetime import datetime | |
| from src.internal.agents.base_agents import Agent, AgentRequest | |
| from src.internal.rag.inference.inferencer import Inferencer | |
| from src.internal.rag.retriever.langchain_retriever import LangChainRetriever | |
| from src.internal.rag.web_search.duckduckgo_search import DDGS | |
| from typing import List, Dict | |
| from langchain_core.documents import Document | |
| import copy | |
| from src.internal.rag.inference.inferencer_types import ( | |
| chunk_response, error_response, complete_response, meta_data_response | |
| ) | |
| from openai import OpenAI | |
class GPTExecutorAgent(Agent):
    """Agent that answers a question with an OpenAI chat model and streams the reply.

    Optionally runs a web search (results are added to the retriever) and a
    retrieval step before the model call; the retrieved text is injected into
    the prompt template as ``context``.

    Attributes:
        openai_client: Client used for ``chat.completions.create``.
        retriever: Document store used for retrieval and for indexing new documents.
        ddgs: Web-search client; ``search`` is consumed as an async iterator.
        full_response: Every streamed delta appended so far.
            NOTE(review): this is never reset, so it keeps growing across
            ``get_result`` calls on the same instance -- confirm that is intended.
        file_paths: Files indexed by ``load_documents``.
    """

    def __init__(self, ddgs: DDGS, retriever: LangChainRetriever, openai_client: OpenAI):
        super().__init__(None)
        self.openai_client = openai_client
        self.retriever = retriever
        self.ddgs = ddgs
        self.full_response = ""
        self.file_paths = [
            "data/documents/dokumen.pdf",
        ]

    async def load_documents(self):
        """Index every file listed in ``self.file_paths``, one after another."""
        for file_path in self.file_paths:
            await self.add_doc(file_path)

    async def add_doc(self, file_path):
        """Index a single file through the retriever and return the retriever's result.

        The retriever injected in ``__init__`` is used directly; this class never
        sets up an inferencer (``None`` is passed to the base constructor).
        """
        return await self.retriever.add_document_from_file(file_path)

    async def _retrieve_documents(self, query: str):
        """Return the documents the retriever finds for ``query`` as a list."""
        retrieval_result = await self.retriever.retrieve(query=query)
        return list(retrieval_result.documents)

    @staticmethod
    def _format_contexts(documents) -> str:
        """Render documents as a 1-based numbered list, one ``page_content`` per line."""
        return "".join(
            f"{i}. {doc.page_content}\n" for i, doc in enumerate(documents, 1)
        )

    async def retrieve_context(self, query: str):
        """Retrieve documents for ``query`` and return them as one numbered string."""
        documents = await self._retrieve_documents(query)
        return self._format_contexts(documents)

    async def web_search(self, query: str):
        """Search the web for ``query`` and add the hits to the retriever.

        Up to 5 results are requested. Each one is wrapped in a ``Document``
        tagged with its source and the originating query.

        Returns:
            The list of ``Document`` objects created from the search results.
        """
        search_results = [
            Document(
                page_content=result,
                metadata={"source": "internet_search", "query": query},
            )
            async for result in self.ddgs.search(query, max_results=5)
        ]
        # Nothing to index when the search came back empty.
        if search_results:
            await self.retriever.add_documents(search_results)
        return search_results

    def format_prompt(self, question: str, context: str, prompt_template):
        """Return a copy of ``prompt_template`` with its placeholders filled in.

        Every message whose ``content`` is a string is formatted with
        ``question`` and ``context``. Messages without string content are
        passed through unchanged. The template itself is not modified.
        """
        formatted_prompt = []
        for message in copy.deepcopy(prompt_template):
            content = message.get("content")
            # Non-string content (for example None) has nothing to format.
            if isinstance(content, str):
                message["content"] = content.format(
                    question=question,
                    context=context,
                )
            formatted_prompt.append(message)
        return formatted_prompt

    async def get_result(self, req: AgentRequest):
        """Answer ``req.question`` and stream the result as response events.

        Yields, in order:
            * one ``meta_data_response`` (only when ``req.enable_retrieval`` is set),
            * one ``chunk_response`` per non-empty streamed delta,
            * one ``complete_response`` when the stream ends, or one
              ``error_response`` if the model call or streaming fails.
        """
        start_time = datetime.now()
        chat_memory = req.chat_memory

        if req.enable_search:
            # The coroutine must be awaited or the search never runs. It comes
            # before retrieval so that fresh results are already indexed.
            try:
                await self.web_search(req.question)
            except Exception as e:
                # The search is an optional enrichment step: log and carry on.
                logging.exception("Web search failed: %s", e)

        if req.enable_retrieval:
            documents = await self._retrieve_documents(req.question)
            contexts = self._format_contexts(documents)
            setup_time = (datetime.now() - start_time).total_seconds()
            yield meta_data_response(
                query=req.question,
                setup_time=setup_time,
                # Count documents, not characters of the formatted string.
                num_contexts=len(documents),
                enable_rerank=False,
            )
        else:
            contexts = "No Context"

        # NOTE: ``+=`` extends req.chat_memory in place, as the original code did.
        chat_memory += self.format_prompt(req.question, contexts, req.prompt_template)

        accumulated_text = ""
        try:
            logging.info(chat_memory)
            # NOTE(review): this is a synchronous client call and a synchronous
            # stream inside an async generator, so it blocks the event loop
            # while waiting on the network.
            response = self.openai_client.chat.completions.create(
                model="gpt-3.5-turbo",
                messages=chat_memory,
                max_tokens=200,
                stream=True,
            )
            response_start = datetime.now()

            for stream_data in response:
                # Some chunks carry no choices; skip them.
                if not stream_data.choices:
                    continue
                choice = stream_data.choices[0]
                delta = choice.delta.content if choice.delta else ""

                # Emit the text before checking for the end of the stream so a
                # delta delivered together with the finish reason is not lost.
                if delta:
                    self.full_response += delta
                    accumulated_text += delta
                    yield chunk_response(
                        chunk=delta,
                        accumulated_text=accumulated_text,
                        generation_times=(datetime.now() - response_start).total_seconds(),
                    )

                # Any finish reason ends the answer, not only "stop".
                if choice.finish_reason is not None:
                    break

            yield complete_response(
                total_time=(datetime.now() - start_time).total_seconds(),
                accumulated_text=accumulated_text,
                contexts=contexts,
            )
        except Exception as e:
            logging.exception("GPT Executor Failed: %s", e)
            yield error_response(e=str(e), meta_data=f"Question is : {req.question}")