Spaces:
Runtime error
Runtime error
| import json | |
| import logging | |
| import os | |
| import sys | |
| from threading import Lock | |
| import gradio as gr | |
| import s3fs | |
| import torch | |
| from langchain.embeddings.huggingface import HuggingFaceEmbeddings | |
| from llama_index import (ServiceContext, StorageContext, | |
| load_index_from_storage, set_global_service_context) | |
| from llama_index.agent import ContextRetrieverOpenAIAgent, OpenAIAgent | |
| from llama_index.indices.vector_store import VectorStoreIndex | |
| from llama_index.llms import ChatMessage, MessageRole, OpenAI | |
| from llama_index.prompts import ChatPromptTemplate, PromptTemplate | |
| from llama_index.query_engine import RetrieverQueryEngine | |
| from llama_index.response_synthesizers import get_response_synthesizer | |
| from llama_index.retrievers import RecursiveRetriever | |
| from llama_index.tools import QueryEngineTool, ToolMetadata | |
| from llama_index.vector_stores import PGVectorStore | |
| from sqlalchemy import make_url | |
def get_embed_model():
    """Load the ``thenlper/gte-small`` sentence-embedding model.

    The model is placed on ``"mps"`` if that backend is available, otherwise on
    ``"cuda"`` if a CUDA device is available, otherwise on ``"cpu"``.
    Embeddings are normalized so similarity scores behave as cosine similarity.

    Returns:
        The loaded ``HuggingFaceEmbeddings`` instance.

    Exits the process with status 1 if the model cannot be loaded: there is
    no fallback model, and the module-level setup needs the embeddings.
    """
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
    # Checked last, so MPS wins if both backends report availability.
    if torch.backends.mps.is_available():
        device = "mps"
    model_kwargs = {"device": device}
    encode_kwargs = {"normalize_embeddings": True}  # set True to compute cosine similarity
    print("Loading model...")
    try:
        model_norm = HuggingFaceEmbeddings(
            model_name="thenlper/gte-small",
            model_kwargs=model_kwargs,
            encode_kwargs=encode_kwargs,
        )
    except Exception as exception:
        # The old code printed "Loading fake model..." and called the
        # site-provided exit(), which exits with status 0; report the real
        # problem and exit non-zero so the failure is visible to whatever
        # launched the process.
        print(f"Failed to load embedding model: {exception}", file=sys.stderr)
        sys.exit(1)
    print("Model loaded.")
    return model_norm
# Shared llama_index defaults: local gte-small embeddings plus GPT-4 as the
# LLM, installed as the global service context for the indexes loaded below.
embed_model = get_embed_model()
llm = OpenAI(model="gpt-4")
service_context = ServiceContext.from_defaults(
    llm=llm,
    embed_model=embed_model,
)
set_global_service_context(service_context)
# S3 client built from the AWS_CANONICAL_* environment variables; a missing
# variable raises KeyError at start-up.
s3 = s3fs.S3FileSystem(
    key=os.environ["AWS_CANONICAL_KEY"],
    secret=os.environ["AWS_CANONICAL_SECRET"],
)
# Each entry under the prefix names one index folder. Keep only the last path
# component; rstrip("/") ensures an entry reported with a trailing slash does
# not turn into an empty title.
titles = [
    path.rstrip("/").split("/")[-1]
    for path in s3.ls("f150-user-manual/recursive-agent/")
]
# One vector query engine per title folder, keyed by title; this dict is the
# recursive retriever's query_engine_dict. The "vector_index" entry is the
# top-level index, which is loaded separately.
agents = {}
for title in titles:
    if title != "vector_index":
        print(title)
        # Load this title's persisted vector index directly from S3.
        storage_context = StorageContext.from_defaults(
            persist_dir=f"f150-user-manual/recursive-agent/{title}/vector_index",
            fs=s3,
        )
        vector_index = load_index_from_storage(storage_context)
        vector_query_engine = vector_index.as_query_engine(
            verbose=True,
            similarity_top_k=2,
        )
        agents[title] = vector_query_engine
print(f"Agents: {len(agents)}")
# Top-level index over the titles. The recursive retriever starts from its
# "vector" retriever (top-1 match) and can hand off to the per-title query
# engines in `agents`. The "{response}" template presumably keeps only the
# sub-engine's answer text.
storage_context = StorageContext.from_defaults(
    persist_dir="f150-user-manual/recursive-agent/vector_index",
    fs=s3,
)
top_level_vector_index = load_index_from_storage(storage_context)
vector_retriever = top_level_vector_index.as_retriever(similarity_top_k=1)
recursive_retriever = RecursiveRetriever(
    root_id="vector",
    retriever_dict={"vector": vector_retriever},
    query_engine_dict=agents,
    query_response_tmpl="{response}",
    verbose=True,
)
# Serializes use of the shared recursive_retriever across the Gradio handlers.
lock = Lock()


def predict(message):
    """Return the text of the top node retrieved for *message*.

    Wired to the hidden "Predict" button. Retrieval runs while holding the
    module-level ``lock``; the lock is released even if retrieval fails.

    Args:
        message: The query string.

    Returns:
        The text of the first node returned by ``recursive_retriever``.

    Raises:
        Exception: Any retriever error is printed and re-raised unchanged
            (including ``IndexError`` if no nodes are returned).
    """
    print(message)
    with lock:
        try:
            return recursive_retriever.retrieve(message)[0].get_text()
        except Exception as e:
            print(e)
            raise
def getanswer(question, history):
    """Answer *question* and append the exchange to the chat history.

    Args:
        question: The user's message (a string, or an object exposing it via
            ``.value``).
        history: List of ``(question, answer)`` tuples held in ``gr.State``;
            ``None`` (the initial state) is treated as empty. An object
            exposing ``.value`` is unwrapped first.

    Returns:
        ``(history, history, update)``: the history for the Chatbot and the
        State outputs, plus a Gradio update that clears the textbox.

    Raises:
        Exception: Retriever errors propagate (including ``IndexError`` if no
            nodes are returned); the history is not modified in that case.
    """
    print("getting answer")
    # Unwrap component-like objects that carry their payload in ``.value``.
    if hasattr(history, "value"):
        history = history.value
    if hasattr(question, "value"):
        question = question.value
    history = history or []
    # A blank submission would still run a full retrieval (and possibly a
    # sub-query-engine call) for nothing; just clear the textbox.
    if question is None or not str(question).strip():
        return history, history, gr.update(value="")
    with lock:
        answer = recursive_retriever.retrieve(question)[0].get_text()
    history.append((question, answer))
    return history, history, gr.update(value="")
# Chat UI: Enter in the textbox and the Send button both run getanswer, which
# updates the Chatbot, the per-session State, and clears the textbox. A hidden
# "Predict" button exposes predict as a separate event.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=0.75):
            with gr.Row():
                gr.Markdown("<h1>F150 User Manual</h1>")
            # Height is a constructor argument: Component.style() was
            # deprecated in Gradio 3.x and removed in Gradio 4.0.
            chatbot = gr.Chatbot(elem_id="chatbot", height=600)
            with gr.Row():
                message = gr.Textbox(
                    label="",
                    placeholder="F150 User Manual",
                    lines=1,
                )
            with gr.Row():
                submit = gr.Button(value="Send", variant="primary", scale=1)
    state = gr.State()
    submit.click(
        getanswer,
        inputs=[message, state],
        outputs=[chatbot, state, message],
    )
    message.submit(
        getanswer,
        inputs=[message, state],
        outputs=[chatbot, state, message],
    )
    predictBtn = gr.Button(value="Predict", visible=False)
    predictBtn.click(predict, inputs=[message], outputs=[message])

demo.launch(debug=True)