# F150 User Manual — recursive-retriever Gradio demo (Hugging Face Space)
import json
import logging
import os
import sys
from threading import Lock
import gradio as gr
import s3fs
import torch
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from llama_index import (ServiceContext, StorageContext,
load_index_from_storage, set_global_service_context)
from llama_index.agent import ContextRetrieverOpenAIAgent, OpenAIAgent
from llama_index.indices.vector_store import VectorStoreIndex
from llama_index.llms import ChatMessage, MessageRole, OpenAI
from llama_index.prompts import ChatPromptTemplate, PromptTemplate
from llama_index.query_engine import RetrieverQueryEngine
from llama_index.response_synthesizers import get_response_synthesizer
from llama_index.retrievers import RecursiveRetriever
from llama_index.tools import QueryEngineTool, ToolMetadata
from llama_index.vector_stores import PGVectorStore
from sqlalchemy import make_url
def get_embed_model():
    """Load the gte-small sentence-embedding model on the best available device.

    Device preference: MPS (checked last, so it wins on Apple Silicon) > CUDA
    > CPU.  Embeddings are normalized so a dot product equals cosine
    similarity.

    Returns:
        HuggingFaceEmbeddings: the loaded embedding model.

    Exits:
        The whole process with status 1 if the model cannot be loaded — the
        app is unusable without embeddings.
    """
    model_kwargs = {'device': 'cpu'}
    if torch.cuda.is_available():
        model_kwargs['device'] = 'cuda'
    # Checked after CUDA so MPS takes precedence on Apple Silicon machines.
    if torch.backends.mps.is_available():
        model_kwargs['device'] = 'mps'
    encode_kwargs = {'normalize_embeddings': True}  # set True to compute cosine similarity
    print("Loading model...")
    try:
        model_norm = HuggingFaceEmbeddings(
            model_name="thenlper/gte-small",
            model_kwargs=model_kwargs,
            encode_kwargs=encode_kwargs,
        )
    except Exception as exception:
        # There is no fallback model; report the real failure and exit with a
        # non-zero status (bare exit() returned 0, masking the error).
        print(f"Embedding model could not be loaded: {exception}")
        sys.exit(1)
    print("Model loaded.")
    return model_norm
# --- Global LlamaIndex service context: GPT-4 for generation, gte-small for
# embeddings.  Set globally so every index/engine below inherits it. ---
embed_model = get_embed_model()
llm = OpenAI("gpt-4")
service_context = ServiceContext.from_defaults(llm=llm, embed_model=embed_model)
set_global_service_context(service_context)

# S3 bucket holding one persisted vector index per manual section, plus a
# top-level "vector_index" directory (handled separately below).
s3 = s3fs.S3FileSystem(
    key=os.environ["AWS_CANONICAL_KEY"],
    secret=os.environ["AWS_CANONICAL_SECRET"],
)
# Section titles are the final path component of each S3 listing entry.
titles = [entry.split("/")[-1] for entry in s3.ls("f150-user-manual/recursive-agent/")]
agents = {}
for title in titles:
    # The top-level "vector_index" directory is not a section; it is loaded
    # separately below as the routing index.
    if title == "vector_index":
        continue
    print(title)
    # Rehydrate the persisted per-section vector index straight from S3.
    storage_context = StorageContext.from_defaults(
        persist_dir=f"f150-user-manual/recursive-agent/{title}/vector_index",
        fs=s3,
    )
    vector_index = load_index_from_storage(storage_context)
    # One query engine per section; the recursive retriever routes
    # section-level hits to these by title.
    agents[title] = vector_index.as_query_engine(
        similarity_top_k=2,
        verbose=True,
    )
print(f"Agents: {len(agents)}")
# Top-level index over the sections: retrieves the single most relevant
# section, then RecursiveRetriever forwards the query to that section's
# query engine (registered in `agents`).
storage_context = StorageContext.from_defaults(
    # plain string — the original used an f-string with no placeholders
    persist_dir="f150-user-manual/recursive-agent/vector_index",
    fs=s3,
)
top_level_vector_index = load_index_from_storage(storage_context)
vector_retriever = top_level_vector_index.as_retriever(similarity_top_k=1)
recursive_retriever = RecursiveRetriever(
    "vector",
    retriever_dict={"vector": vector_retriever},
    query_engine_dict=agents,
    verbose=True,
    # Pass the sub-engine's answer through unchanged.
    query_response_tmpl="{response}",
)
# Serializes access to the shared retriever: predict() and getanswer() hold
# this lock so only one query runs at a time.
lock = Lock()
def predict(message):
    """Answer *message* via the recursive retriever and return the raw text.

    Serialized with the module-level lock because the retriever is shared
    across Gradio callbacks.  Failures are printed for the server log and
    then propagated to the caller.

    Args:
        message (str): the user's question.

    Returns:
        str: text of the top retrieved node.

    Raises:
        IndexError: if the retriever returns no nodes.
    """
    print(message)
    with lock:  # `with` replaces manual acquire()/release()-in-finally
        try:
            return recursive_retriever.retrieve(message)[0].get_text()
        except Exception as e:
            print(e)
            raise  # bare raise preserves the original traceback
def getanswer(question, history):
    """Gradio callback: answer *question* and append the turn to *history*.

    Args:
        question: the user's question — a str, or a component wrapper
            exposing the value via `.value`.
        history: chat history as a list of (question, answer) pairs, or None
            on the first turn.

    Returns:
        tuple: (history, history, clear-textbox update) — history appears
        twice because it feeds both the Chatbot and the State outputs.
    """
    print("getting answer")
    # Some Gradio versions hand over component wrappers instead of raw values.
    if hasattr(history, "value"):
        history = history.value
    if hasattr(question, "value"):
        question = question.value
    history = history or []
    # Serialize retriever access; `with` guarantees release even on error
    # (the old `except Exception: raise e` only re-raised and is dropped),
    # and exceptions still propagate to Gradio.
    with lock:
        output = recursive_retriever.retrieve(question)[0]
        history.append((question, output.get_text()))
    return history, history, gr.update(value="")
# --- Gradio UI: single-column chat over the F150 user manual. ---
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=0.75):
            with gr.Row():
                gr.Markdown("<h1>F150 User Manual</h1>")
            # NOTE(review): `.style()` and float `scale` are Gradio 3.x APIs,
            # removed in Gradio 4 — confirm the pinned Gradio version.
            chatbot = gr.Chatbot(elem_id="chatbot").style(height=600)
            with gr.Row():
                message = gr.Textbox(
                    label="",
                    placeholder="F150 User Manual",
                    lines=1,
                )
            with gr.Row():
                submit = gr.Button(value="Send", variant="primary", scale=1)
    # Holds the (question, answer) history across callbacks.
    state = gr.State()
    # Both the Send button and pressing Enter in the textbox submit a turn;
    # the third output clears the textbox afterwards.
    submit.click(getanswer, inputs=[message, state], outputs=[chatbot, state, message])
    message.submit(getanswer, inputs=[message, state], outputs=[chatbot, state, message])
    # Hidden button exposing predict() as an invisible endpoint (reachable
    # programmatically even though it is not shown in the UI).
    predictBtn = gr.Button(value="Predict", visible=False)
    predictBtn.click(predict, inputs=[message], outputs=[message])
demo.launch(debug=True)