File size: 4,464 Bytes
cbdf795
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f3e6f47
cbdf795
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f6811a
cbdf795
 
 
 
 
 
 
 
4f6811a
 
cbdf795
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f6811a
 
cbdf795
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import json
import logging
import os
import sys
from threading import Lock

import gradio as gr
import s3fs
import torch
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from llama_index import (ServiceContext, StorageContext,
                         load_index_from_storage, set_global_service_context)
from llama_index.agent import ContextRetrieverOpenAIAgent, OpenAIAgent
from llama_index.indices.vector_store import VectorStoreIndex
from llama_index.llms import ChatMessage, MessageRole, OpenAI
from llama_index.prompts import ChatPromptTemplate, PromptTemplate
from llama_index.query_engine import RetrieverQueryEngine
from llama_index.response_synthesizers import get_response_synthesizer
from llama_index.retrievers import RecursiveRetriever
from llama_index.tools import QueryEngineTool, ToolMetadata
from llama_index.vector_stores import PGVectorStore
from sqlalchemy import make_url


def get_embed_model():
    """Load the gte-small HuggingFace embedding model on the best device.

    Device preference (last check wins): CPU -> CUDA -> MPS, so Apple
    Silicon uses MPS even when CUDA reports available.

    Returns:
        HuggingFaceEmbeddings: embedder with normalized outputs, so cosine
        similarity reduces to a dot product.

    Exits the process with status 1 if the model cannot be loaded.
    """
    model_kwargs = {'device': 'cpu'}
    if torch.cuda.is_available():
      model_kwargs['device'] = 'cuda'
    # Guard the attribute: older torch builds do not expose backends.mps,
    # and a plain torch.backends.mps access would raise AttributeError.
    if getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
      model_kwargs['device'] = 'mps'

    encode_kwargs = {'normalize_embeddings': True} # set True to compute cosine similarity
    print("Loading model...")
    try:
      model_norm = HuggingFaceEmbeddings(
        model_name="thenlper/gte-small",
        model_kwargs=model_kwargs,
        encode_kwargs=encode_kwargs,
      )
    except Exception as exception:
      # Fix: bare exit() is a site-module REPL helper; sys.exit(1) properly
      # signals failure to the shell. The old "Loading fake model" message
      # was misleading — nothing else is loaded, the process stops here.
      print(f"Failed to load embedding model: {exception}")
      sys.exit(1)
    print("Model loaded.")
    return model_norm

# Global service context: every index loaded below uses this LLM + embedder.
embed_model = get_embed_model()
llm = OpenAI("gpt-4")
service_context = ServiceContext.from_defaults(llm=llm, embed_model=embed_model)
set_global_service_context(service_context)

# Credentials come from the environment; a KeyError here means missing config.
s3 = s3fs.S3FileSystem(
  key=os.environ["AWS_CANONICAL_KEY"],
  secret=os.environ["AWS_CANONICAL_SECRET"],
)

# Keep only the final path component of each S3 entry as the section title.
titles = [entry.split("/")[-1] for entry in s3.ls("f150-user-manual/recursive-agent/")]

# Build one query engine per manual section. The "vector_index" entry is the
# top-level routing index, not a section, so it is skipped here.
agents = {}
for title in titles:
  if title == "vector_index":
    continue

  print(title)
  # Rehydrate this section's persisted vector index straight from S3.
  section_storage = StorageContext.from_defaults(
    persist_dir=f"f150-user-manual/recursive-agent/{title}/vector_index", fs=s3
  )
  section_index = load_index_from_storage(section_storage)

  # Each section engine answers from its two most similar chunks.
  agents[title] = section_index.as_query_engine(similarity_top_k=2, verbose=True)
print(f"Agents: {len(agents)}")

# The top-level index routes a query to the single best-matching section;
# that section's query engine then produces the response text verbatim
# (query_response_tmpl passes the response through unchanged).
storage_context = StorageContext.from_defaults(
  persist_dir="f150-user-manual/recursive-agent/vector_index", fs=s3
)
top_level_vector_index = load_index_from_storage(storage_context)
vector_retriever = top_level_vector_index.as_retriever(similarity_top_k=1)
recursive_retriever = RecursiveRetriever(
    "vector",
    retriever_dict={"vector": vector_retriever},
    query_engine_dict=agents,
    verbose=True,
    query_response_tmpl="{response}"
)

# Serializes retriever access across concurrent Gradio callbacks.
lock = Lock()

def predict(message):
  """Retrieve the answer text for *message* via the recursive retriever.

  Args:
    message: the user's question as plain text.

  Returns:
    The text of the top retrieved node.

  Raises:
    Whatever the retriever raises (printed first), or IndexError when
    retrieval returns no nodes.
  """
  print(message)
  # Idiom fix: `with lock` replaces manual acquire()/release() and still
  # guarantees release if retrieval raises.
  with lock:
    try:
      top_node = recursive_retriever.retrieve(message)[0]
      return top_node.get_text()
    except Exception as e:
      # Log before propagating so failures appear in server output.
      # Bare `raise` (not `raise e`) preserves the original traceback.
      print(e)
      raise

def getanswer(question, history):
  """Gradio callback: answer *question* and append the (q, a) pair to history.

  Args:
    question: the user's question; Gradio may hand over a component-like
      object, in which case its .value attribute is unwrapped.
    history: list of (question, answer) tuples, a wrapper with .value,
      or None on the first call.

  Returns:
    (history, history, update) — chatbot contents, persisted state, and a
    gr.update that clears the input textbox.
  """
  print("getting answer")
  # Unwrap Gradio component objects to their raw values.
  if hasattr(history, "value"):
    history = history.value
  if hasattr(question, "value"):
    question = question.value

  history = history or []
  # Idiom fix: `with lock` replaces manual acquire()/release(). The old
  # `except Exception as e: raise e` did nothing but mangle the traceback,
  # so it is dropped — the context manager already releases on error.
  with lock:
    top_node = recursive_retriever.retrieve(question)[0]
    history.append((question, top_node.get_text()))
  return history, history, gr.update(value="")

# --- Gradio UI ---------------------------------------------------------------
# Single-column chat layout: markdown title, chatbot pane, input box, button.
with gr.Blocks() as demo:
  with gr.Row():
    # NOTE(review): float scale and .style() below are legacy (pre-4.x)
    # Gradio APIs — confirm the pinned gradio version supports them.
    with gr.Column(scale=0.75):
      with gr.Row():
        gr.Markdown("<h1>F150 User Manual</h1>")
      chatbot = gr.Chatbot(elem_id="chatbot").style(height=600)

      with gr.Row():
          message = gr.Textbox(
              label="",
              placeholder="F150 User Manual",
              lines=1,
          )
      with gr.Row():
          submit = gr.Button(value="Send", variant="primary", scale=1)

      # Session state holds the (question, answer) history between turns.
      # Both the Send button and pressing Enter in the textbox trigger
      # getanswer; the third output clears the textbox.
      state = gr.State()
      submit.click(getanswer, inputs=[message, state], outputs=[chatbot, state, message])
      message.submit(getanswer, inputs=[message, state], outputs=[chatbot, state, message])

      # Hidden button exposing predict(); writes the raw answer back into
      # the textbox (outputs=[message]) — presumably for programmatic/API
      # use since it is not visible in the UI.
      predictBtn = gr.Button(value="Predict", visible=False)
      predictBtn.click(predict, inputs=[message], outputs=[message])

# debug=True blocks and streams errors to the console.
demo.launch(debug=True)