# Source: Hugging Face Space file view (Space status: "Build error"; file size: 4,592 bytes; commit: 257dcc1).
import logging
import os
import gradio as gr
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.llms import MessageRole
from llama_index.core.memory import ChatSummaryMemoryBuffer
from llama_index.core.tools import RetrieverTool, ToolMetadata
from llama_index.agent.openai import OpenAIAgent
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai import OpenAI
from llama_index.core import Settings
from llama_index.core.postprocessor import LLMRerank
from utils import create_db, load_db, load_asset
from config import CHROMA_PATH, PLACEHOLDER, TITLE, PROMPT_SYSTEM_MESSAGE, TEXT_QA_TEMPLATE
# Root handler at INFO so this module's progress messages are printed.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Only WARNING and above from httpx reach the log.
logging.getLogger("httpx").setLevel(logging.WARNING)

# OpenAI key placeholder; starts out empty.
API_KEY = ""
# Running, process-wide count of streamed response chunks.
token_count = 0
def create_knowledge_base_if_not_exists():
    """Build the Chroma vector store with ``create_db()`` when it is absent.

    The store is considered present only when ``CHROMA_PATH`` is a directory
    that contains at least one entry. Otherwise ``create_db()`` is called.
    """
    # Check isdir first: os.listdir() raises NotADirectoryError on a plain file.
    if os.path.isdir(CHROMA_PATH) and os.listdir(CHROMA_PATH):
        return
    # Use the module logger (not print) so this shows up with the other app logs.
    logger.warning("ChromaDB not found at %s. Creating DB...", CHROMA_PATH)
    create_db()
def get_tools():
    """Build the agent's tool list: one retriever tool over the local Chroma index.

    Returns:
        A one-element list holding a ``RetrieverTool``. The tool takes the 15
        most similar nodes from the index returned by ``load_db()`` and then
        applies ``LLMRerank`` to keep the top 3.
    """
    index = load_db()
    vector_retriever = VectorIndexRetriever(
        index=index,
        similarity_top_k=15,
        embed_model=Settings.embed_model,
        use_async=True,
    )
    # Second-stage LLM reranking of the 15 candidates down to 3. LLMRerank
    # needs the query text, so it is registered as a node postprocessor
    # (RetrieverTool applies postprocessors with the query) instead of being
    # called by hand on the retrieved nodes.
    reranker = LLMRerank(
        choice_batch_size=5,
        top_n=3,
    )
    return [
        RetrieverTool(
            # RetrieverTool calls ``.retrieve()`` on this object, so it must be
            # the retriever itself, not a plain function wrapping it.
            retriever=vector_retriever,
            node_postprocessors=[reranker],
            metadata=ToolMetadata(
                # NOTE(review): "LitleJS" is misspelled; kept unchanged in case
                # PROMPT_SYSTEM_MESSAGE refers to the tool by this name — confirm.
                name="LitleJS_related_resources",
                description="Useful for info related to the LittleJS game development library. It gathers the info from local data.",
            ),
        )
    ]
def set_api_key(key):
    """Install OpenAI LLM and embedding clients that use ``key`` into ``Settings``.

    Args:
        key: OpenAI API key entered by the user.

    Both clients receive the key. Without it, the embedding model would
    presumably fall back to the ``OPENAI_API_KEY`` environment variable, and
    retrieval would then fail for users who typed their key only in the UI.

    ``key`` is intentionally not cached in the module-level ``API_KEY``.
    ``Settings`` is process-wide and shared by every Gradio session, so
    remembering the first caller's key would make later callers reuse it.
    """
    Settings.llm = OpenAI(temperature=0, model="gpt-4o-mini", api_key=key)
    Settings.embed_model = OpenAIEmbedding(model="text-embedding-3-small", api_key=key)
def _count_user_turns(history):
    """Return how many user messages a Gradio ``type='messages'`` history holds."""
    count = 0
    for msg in history or []:
        # Messages are normally role/content dicts; tolerate objects with a .role too.
        role = msg.get("role") if isinstance(msg, dict) else getattr(msg, "role", None)
        if role == "user":
            count += 1
    return count


def generate_completion(query, history, memory, api_key):
    """Stream the agent's answer to ``query``, yielding the growing reply text.

    Args:
        query: The user's new message.
        history: Gradio chat history in ``type='messages'`` form. It is a list
            of role/content messages covering both user and assistant turns.
        memory: This session's ``ChatSummaryMemoryBuffer`` (from ``gr.State``).
        api_key: OpenAI key typed into the UI.

    Yields:
        The accumulated answer string after each streamed chunk.
    """
    global token_count

    logger.info("User query: %s", query)
    # Settings is shared by all sessions, so install this caller's key on every
    # request rather than only once.
    set_api_key(api_key)

    # Keep agent memory aligned with the visible chat. If the user cleared,
    # undid or retried turns in Gradio, cut memory at the first user message
    # that no longer appears in ``history``. ``history`` holds assistant
    # messages too, so count only its user messages. get_all() reads the stored
    # messages without triggering summarization.
    # NOTE(review): once memory has been summarized, its user count can fall
    # below the visible one and no trimming happens.
    chat_list = memory.get_all()
    user_positions = [i for i, msg in enumerate(chat_list) if msg.role == MessageRole.USER]
    visible_user_turns = _count_user_turns(history)
    if len(user_positions) > visible_user_turns:
        memory.set(chat_list[: user_positions[visible_user_turns]])

    stored = memory.get_all()
    logger.info("chat_history: %d messages, gradio_history: %d messages", len(stored), len(history))
    # Full transcripts contain user content; keep them out of INFO-level logs.
    logger.debug("chat_history: %s", stored)
    logger.debug("gradio_history: %s", history)

    agent = OpenAIAgent.from_tools(
        llm=Settings.llm,
        memory=memory,
        tools=get_tools(),
        system_prompt=PROMPT_SYSTEM_MESSAGE,
    )

    completion = agent.stream_chat(query)
    answer_str = ""
    for token in completion.response_gen:
        answer_str += token
        # Counts streamed chunks, which only approximates model tokens.
        token_count += 1
        yield answer_str
def launch_ui():
    """Build the Gradio chat interface and launch the server.

    The page contains:
    - an OpenAI-key textbox;
    - a chatbot driven by ``generate_completion``, which receives the
      session's memory buffer and the key as additional inputs;
    - a button whose label shows the module-level ``token_count``.
    """
    with gr.Blocks(
        title=TITLE,
        fill_height=True,
        analytics_enabled=True,
        css=load_asset("./assets/style.css"),
        # Load chat.js once, here (it used to be read a second time into an unused local).
        js=load_asset("./assets/chat.js"),
    ) as demo:
        api_key_input = gr.Textbox(
            label="Enter your OpenAI API Key",
            type="password",
            placeholder="sk-...",
            elem_classes="api_key_input",
        )
        # A callable State value is invoked on each page load, so every session
        # gets its own memory buffer.
        memory_state = gr.State(
            lambda: ChatSummaryMemoryBuffer.from_defaults(
                token_limit=120000,
            )
        )
        chatbot = gr.Chatbot(
            scale=1,
            placeholder=PLACEHOLDER,
            type='messages',
            show_label=False,
            show_copy_button=True,
            elem_classes="chatbox",
        )
        gr.ChatInterface(
            fn=generate_completion,
            chatbot=chatbot,
            type='messages',
            additional_inputs=[memory_state, api_key_input],
        )
        token_counter = gr.Button("Tokens Used: 0", elem_classes="token_counter")
        # Refresh the counter label whenever the chat changes (including while a
        # reply streams). token_count is a process-wide total across sessions.
        chatbot.change(
            fn=lambda: f"Tokens Used: {token_count}",
            inputs=None,
            outputs=token_counter,
            show_progress="hidden",
        )

    demo.queue(default_concurrency_limit=64)
    demo.launch(debug=True, favicon_path="./assets/favicon.png", share=False)  # Set share=True to share the app online
if __name__ == "__main__":
    # Make sure the Chroma store exists before any request can query it.
    create_knowledge_base_if_not_exists()
    # Then build and serve the Gradio UI.
    launch_ui()
|