achuthc1298 committed
Commit 8c4fcbf · verified · 1 Parent(s): 3dc51bb

Update app.py

Files changed (1)
  1. app.py +150 -59
app.py CHANGED
@@ -1,63 +1,154 @@
  import gradio as gr
- from huggingface_hub import InferenceClient
-
- """
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
- """
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
-
-
- def respond(
-     message,
-     history: list[tuple[str, str]],
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
- ):
-     messages = [{"role": "system", "content": system_message}]
-
-     for val in history:
-         if val[0]:
-             messages.append({"role": "user", "content": val[0]})
-         if val[1]:
-             messages.append({"role": "assistant", "content": val[1]})
-
-     messages.append({"role": "user", "content": message})
-
-     response = ""
-
-     for message in client.chat_completion(
-         messages,
-         max_tokens=max_tokens,
-         stream=True,
-         temperature=temperature,
-         top_p=top_p,
-     ):
-         token = message.choices[0].delta.content
-
-         response += token
-         yield response
-
- """
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
- """
- demo = gr.ChatInterface(
-     respond,
-     additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.95,
-             step=0.05,
-             label="Top-p (nucleus sampling)",
-         ),
-     ],
- )
-
 
  if __name__ == "__main__":
-     demo.launch()
 
  import gradio as gr
+ import os
+ from pathlib import Path
+ from llama_index.core.selectors import LLMSingleSelector
+ from llama_index.core.tools import QueryEngineTool
+ from llama_index.core import VectorStoreIndex
+ from llama_index.core import Settings
+ from llama_index.core import SimpleDirectoryReader
+ from llama_index.llms.groq import Groq
+ from llama_index.embeddings.huggingface import HuggingFaceEmbedding
+ from typing import Tuple
+ from llama_index.core import StorageContext, load_index_from_storage
+ from llama_index.core.objects import ObjectIndex
+ from llama_index.core.agent import ReActAgent
+ import sys
+ import io
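+
+ # Note: these imports assume the `llama-index-llms-groq` and
+ # `llama-index-embeddings-huggingface` integration packages are installed
+ # alongside `llama-index-core`.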
 
+ # Function to process files and create document tools
+ def create_doc_tools(document_fp: str, doc_name: str, verbose: bool = True) -> QueryEngineTool:
+     documents = SimpleDirectoryReader(input_files=[document_fp]).load_data()
+
+     Settings.llm = Groq(model="mixtral-8x7b-32768")
+     Settings.embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-large-en-v1.5")
+
+     # Load the pre-built index for this document from disk rather than
+     # re-indexing the documents loaded above
+     load_dir_path = f"/home/user/app/agentic_index/{doc_name}"
+     storage_context = StorageContext.from_defaults(persist_dir=load_dir_path)
+     vector_index = load_index_from_storage(storage_context)
+     vector_query_engine = vector_index.as_query_engine()
+
+     vector_tool = QueryEngineTool.from_defaults(
+         name=f"{doc_name}_vector_query_engine_tool",
+         query_engine=vector_query_engine,
+         description=f"Useful for retrieving specific context from the {doc_name}.",
+     )
+
+     return vector_tool
+
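+ # The index loaded above is assumed to have been built and persisted ahead
+ # of time. A minimal sketch of how such an index could be created
+ # (hypothetical; not part of this commit):
+ #
+ #     index = VectorStoreIndex.from_documents(documents)
+ #     index.storage_context.persist(persist_dir=load_dir_path)
+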
+ # Function to find and sort .tex and .txt files
+ def find_tex_files(directory: str):
+     tex_files = []
+     for root, dirs, files in os.walk(directory):
+         for file in files:
+             if file.endswith(('.tex', '.txt')):
+                 file_path = os.path.abspath(os.path.join(root, file))
+                 tex_files.append(file_path)
+     tex_files.sort()
+     return tex_files
+
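+ # Hypothetical usage (file names are illustrative):
+ #
+ #     find_tex_files("/home/user/app/rag_docs_final_review_tex_merged")
+ #     # -> ['/home/user/app/rag_docs_final_review_tex_merged/paper_a.tex', ...]
+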
+ def initialize_agent(apikey):
+     os.environ["GROQ_API_KEY"] = apikey
+
+     llm = Groq(model="mixtral-8x7b-32768")
+
+     try:
+         directory = '/home/user/app/rag_docs_final_review_tex_merged'
+         tex_files = find_tex_files(directory)
+
+         # Build one query-engine tool per document
+         paper_to_tools_dict = {}
+         for paper in tex_files:
+             path = Path(paper)
+             vector_tool = create_doc_tools(doc_name=path.stem, document_fp=str(path))
+             paper_to_tools_dict[path.stem] = [vector_tool]
+
+         initial_tools = [t for paper in tex_files for t in paper_to_tools_dict[Path(paper).stem]]
+
+         # Index the tools themselves so the relevant ones can be retrieved per query
+         obj_index = ObjectIndex.from_objects(
+             initial_tools,
+             index_cls=VectorStoreIndex,
+         )
+
+         obj_retriever = obj_index.as_retriever(similarity_top_k=6)
+
+         context = """You are an agent designed to answer scientific queries over a set of given documents.
+         Please always use the tools provided to answer a question. Do not rely on prior knowledge.
+         """
+
+         agent = ReActAgent.from_tools(
+             tool_retriever=obj_retriever,
+             llm=llm,
+             verbose=True,
+             context=context
+         )
+         return agent
+     except Exception as e:
+         return str(e)
+
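+ # With similarity_top_k=6, each query sees only the six tools whose
+ # descriptions best match the question, rather than every document tool at
+ # once. A standalone smoke test might look like this (sketch, assuming a
+ # valid Groq API key):
+ #
+ #     agent = initialize_agent("<GROQ_API_KEY>")
+ #     print(agent.query("Summarize the main findings."))
+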
+ def chat_with_agent(prompt, agent, verbose_toggle):
+     # Redirect stdout to capture the agent's verbose reasoning trace
+     original_stdout = sys.stdout
+     sys.stdout = io.StringIO()
+
+     try:
+         # Query the agent
+         response = agent.query(prompt)
+
+         # Get the captured output
+         output = sys.stdout.getvalue()
+
+         # Format the captured verbose output
+         verbose = ''
+         for output_string in output.split('==='):
+             verbose += output_string
+             verbose += '\n'
+
+         # Assistant response: either the full trace, or the answer with its
+         # fixed 10-character prefix stripped
+         msg = f'{verbose}' if verbose_toggle else f'{response.response[10:]}'
+
+         return msg
+     except Exception as e:
+         return str(e)
+     finally:
+         # Restore stdout even if the query fails
+         sys.stdout = original_stdout
+
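+ # contextlib.redirect_stdout would scope the same capture more safely
+ # (sketch; would require `import contextlib`):
+ #
+ #     with contextlib.redirect_stdout(io.StringIO()) as buf:
+ #         response = agent.query(prompt)
+ #     output = buf.getvalue()
+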
+ # Gradio Interface
+ def main():
+     # The agent is created lazily, once an API key has been provided
+     agent = None
+
+     def set_apikey(apikey):
+         nonlocal agent
+         agent = initialize_agent(apikey)
+         return "API Key Set. You may start asking questions now."
+
+     def reset_chat():
+         return "Chat reset. You may start asking questions now."
+
+     def chat_function(prompt, apikey, verbose_toggle):
+         if not agent:
+             set_apikey(apikey)
+         return chat_with_agent(prompt, agent, verbose_toggle)
+
+     with gr.Blocks() as demo:
+         gr.Markdown("# AMGPT, Powered by LlamaIndex")
+
+         with gr.Row():
+             apikey = gr.Textbox(label="Enter your Groq API Key", type="password")
+             set_apikey_button = gr.Button("Set API Key")
+
+         # Status messages returned by the handlers are currently discarded
+         # (outputs=None)
+         set_apikey_button.click(set_apikey, inputs=apikey, outputs=None)
+
+         with gr.Row():
+             verbose_toggle = gr.Checkbox(label="Verbose", value=True)
+             reset = gr.Button("Reset Chat")
+
+         reset.click(reset_chat, outputs=None)
+
+         prompt = gr.Textbox(label="Ask a question")
+         output = gr.Textbox(label="Response")
+
+         prompt.submit(chat_function, inputs=[prompt, apikey, verbose_toggle], outputs=output)
+
+     demo.launch()
 
  if __name__ == "__main__":
+     main()