"""Gradio chat UI for a LittleJS assistant built on LlamaIndex.

Context is retrieved from a local ChromaDB index (created on first run when
missing), reranked by an LLM, and answered by an OpenAI agent whose reply is
streamed into a Gradio ``ChatInterface``.
"""

import logging
import os

import gradio as gr
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.llms import MessageRole
from llama_index.core.memory import ChatSummaryMemoryBuffer
from llama_index.core.tools import RetrieverTool, ToolMetadata
from llama_index.agent.openai import OpenAIAgent
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai import OpenAI
from llama_index.core import Settings
from llama_index.core.postprocessor import LLMRerank

from utils import create_db, load_db, load_asset
from config import (
    CHROMA_PATH,
    PLACEHOLDER,
    TITLE,
    PROMPT_SYSTEM_MESSAGE,
    TEXT_QA_TEMPLATE,
)

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
logging.getLogger("httpx").setLevel(logging.WARNING)

# OpenAI API key the global ``Settings`` models were last configured with.
# NOTE(review): ``Settings`` is process-global, so concurrent sessions share
# whichever models were configured most recently — acceptable for a local,
# single-user app, not for a shared deployment.
API_KEY = ""

# Number of streamed response chunks seen so far, across all sessions.
# NOTE(review): stream deltas only approximate tokens, and nothing updates
# the "Tokens Used" button with this value yet.
token_count = 0


def create_knowledge_base_if_not_exists():
    """Build the ChromaDB knowledge base when CHROMA_PATH is missing or empty."""
    if not os.path.exists(CHROMA_PATH) or not os.listdir(CHROMA_PATH):
        logger.warning("⚠️ ChromaDB not found. Creating DB...")
        create_db()


def get_tools():
    """Return the agent's tools: a LittleJS retriever with LLM reranking.

    The vector retriever fetches the 15 most similar chunks from the local
    index and ``LLMRerank`` keeps the 3 the LLM rates most relevant. Uses
    ``Settings.embed_model`` / ``Settings.llm``, so they must be configured
    (see ``set_api_key``) before this is called.
    """
    index = load_db()
    vector_retriever = VectorIndexRetriever(
        index=index,
        similarity_top_k=15,
        embed_model=Settings.embed_model,
        use_async=True,
    )

    # Add LLMRerank for better retrieval.
    reranker = LLMRerank(
        choice_batch_size=5,
        top_n=3,
    )

    # RetrieverTool calls .retrieve()/.aretrieve() on its retriever, so it
    # must receive the retriever object itself; the reranker is attached as a
    # node postprocessor, which the tool runs with the query (LLMRerank
    # needs the query to score nodes).
    return [
        RetrieverTool(
            retriever=vector_retriever,
            metadata=ToolMetadata(
                name="LitleJS_related_resources",
                description=(
                    "Useful for info related to the LittleJS game development "
                    "library. It gathers the info from local data."
                ),
            ),
            node_postprocessors=[reranker],
        )
    ]


def set_api_key(key):
    """Configure the global LlamaIndex LLM and embedding model with *key*.

    An empty key is passed on as ``None`` so the OpenAI clients can fall back
    to the ``OPENAI_API_KEY`` environment variable instead of an empty key.
    """
    global API_KEY
    API_KEY = key
    resolved_key = key or None
    Settings.llm = OpenAI(temperature=0, model="gpt-4o-mini", api_key=resolved_key)
    # The embedding model embeds every retrieval query, so it needs the same
    # key as the LLM rather than relying on the environment alone.
    Settings.embed_model = OpenAIEmbedding(
        model="text-embedding-3-small", api_key=resolved_key
    )


def _count_user_turns(history):
    """Return how many user turns the Gradio chat *history* contains.

    Supports ``type='messages'`` history (dicts or objects carrying a
    ``role``) and legacy ``[user, assistant]`` pair history.
    """
    turns = 0
    for item in history or []:
        if isinstance(item, dict):
            role = item.get("role")
        elif isinstance(item, (list, tuple)):
            # Legacy pair format: every pair is exactly one user turn.
            turns += 1
            continue
        else:
            role = getattr(item, "role", None)
        if role == "user":
            turns += 1
    return turns


def _sync_memory_with_history(memory, history):
    """Drop memory messages belonging to user turns no longer in *history*.

    Gradio's retry/undo remove turns from the visible history; truncating the
    agent memory at the first surplus user message keeps the two in step.
    """
    chat_list = memory.get()
    if not chat_list:
        return
    user_indices = [
        i for i, msg in enumerate(chat_list) if msg.role == MessageRole.USER
    ]
    n_turns = _count_user_turns(history)
    if len(user_indices) > n_turns:
        memory.set(chat_list[: user_indices[n_turns]])


def generate_completion(query, history, memory, api_key):
    """Stream the agent's answer to *query*, yielding the text so far.

    Args:
        query: The user's new message.
        history: Gradio chat history, excluding *query*.
        memory: Per-session ``ChatSummaryMemoryBuffer`` held in ``gr.State``.
        api_key: OpenAI key from the UI textbox (may be empty).

    Yields:
        The accumulated answer string after each streamed chunk.
    """
    global token_count

    logger.info("User query: %s", query)

    # Reconfigure on first use and whenever a different key is submitted, so
    # each request runs with the key it supplied (or the environment's).
    if api_key != API_KEY or not API_KEY:
        set_api_key(api_key)

    _sync_memory_with_history(memory, history)

    chat_history = memory.get()
    logger.info("chat_history: %d %s", len(chat_history), chat_history)
    logger.info("gradio_history: %d %s", len(history), history)

    agent = OpenAIAgent.from_tools(
        llm=Settings.llm,
        memory=memory,
        tools=get_tools(),
        system_prompt=PROMPT_SYSTEM_MESSAGE,
    )

    completion = agent.stream_chat(query)
    answer_str = ""
    for token in completion.response_gen:
        answer_str += token
        token_count += 1
        yield answer_str


def launch_ui():
    """Build the Gradio Blocks chat app and start serving it."""
    with gr.Blocks(
        title=TITLE,
        fill_height=True,
        analytics_enabled=True,
        css=load_asset("./assets/style.css"),
        js=load_asset("./assets/chat.js"),
    ) as demo:
        api_key_input = gr.Textbox(
            label="Enter your OpenAI API Key",
            type="password",
            placeholder="sk-...",
            elem_classes="api_key_input",
        )
        # A callable initial value gives each session its own memory buffer.
        memory_state = gr.State(
            lambda: ChatSummaryMemoryBuffer.from_defaults(
                token_limit=120000,
            )
        )
        chatbot = gr.Chatbot(
            scale=1,
            placeholder=PLACEHOLDER,
            type="messages",
            show_label=False,
            show_copy_button=True,
            elem_classes="chatbox",
        )
        gr.ChatInterface(
            fn=generate_completion,
            chatbot=chatbot,
            type="messages",
            additional_inputs=[memory_state, api_key_input],
        )
        # TODO(review): not yet wired to ``token_count``; always shows 0.
        gr.Button("Tokens Used: 0", elem_classes="token_counter")

        demo.queue(default_concurrency_limit=64)
        demo.launch(
            debug=True,
            favicon_path="./assets/favicon.png",
            share=False,  # Set share=True to share the app online
        )


if __name__ == "__main__":
    create_knowledge_base_if_not_exists()
    launch_ui()