Nelly-43 commited on
Commit
cdc69e6
·
verified ·
1 Parent(s): b0edb89

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +225 -50
app.py CHANGED
@@ -1,70 +1,245 @@
 
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
 
 
 
 
 
3
 
 
 
 
4
 
5
- def respond(
6
- message,
7
- history: list[dict[str, str]],
8
- system_message,
9
- max_tokens,
10
- temperature,
11
- top_p,
12
- hf_token: gr.OAuthToken,
13
- ):
14
- """
15
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
16
- """
17
- client = InferenceClient(token=hf_token.token, model="openai/gpt-oss-20b")
18
 
19
- messages = [{"role": "system", "content": system_message}]
 
 
 
 
 
 
20
 
21
- messages.extend(history)
22
 
23
- messages.append({"role": "user", "content": message})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
- response = ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
- for message in client.chat_completion(
28
- messages,
29
- max_tokens=max_tokens,
30
- stream=True,
31
- temperature=temperature,
32
- top_p=top_p,
33
- ):
34
- choices = message.choices
35
- token = ""
36
- if len(choices) and choices[0].delta.content:
37
- token = choices[0].delta.content
38
 
39
- response += token
40
- yield response
 
 
 
 
 
 
 
41
 
 
 
 
 
 
 
42
 
 
 
 
 
43
  """
44
  For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
45
  """
46
- chatbot = gr.ChatInterface(
47
- respond,
48
- type="messages",
49
- additional_inputs=[
50
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
51
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
52
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
53
- gr.Slider(
54
- minimum=0.1,
55
- maximum=1.0,
56
- value=0.95,
57
- step=0.05,
58
- label="Top-p (nucleus sampling)",
59
- ),
60
- ],
61
- )
62
-
63
  with gr.Blocks() as demo:
64
  with gr.Sidebar():
65
- gr.LoginButton()
66
- chatbot.render()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
 
 
 
 
68
 
69
  if __name__ == "__main__":
70
  demo.launch()
 
1
import glob
import os
import warnings

import gradio as gr
from langchain_community.document_loaders.csv_loader import CSVLoader
from langchain_community.document_loaders import Docx2txtLoader, TextLoader, PyPDFLoader
from langchain_text_splitters import CharacterTextSplitter, RecursiveCharacterTextSplitter, TokenTextSplitter
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from huggingface_hub import snapshot_download, upload_folder

from langchain.tools import tool
from langchain.agents import create_agent
from langchain.agents.middleware import dynamic_prompt, ModelRequest

# `os` and `warnings` were used below but never imported (NameError at startup);
# both are now imported above.
warnings.filterwarnings('ignore')
os.environ["WANDB_DISABLED"] = "true"

# Pull the reference list and the document corpus from their HF dataset repos.
snapshot_download(repo_id="CGIAR/weai-ref",
                  repo_type="dataset",
                  token=os.getenv('HF_TOKEN'),
                  local_dir='./refs'
                  )

snapshot_download(repo_id="CGIAR/weai-docs",
                  repo_type="dataset",
                  token=os.getenv('HF_TOKEN'),
                  local_dir='./docs'
                  )

# Chat model served through the HF Inference API ("conversational" task).
repo_id = "meta-llama/Llama-3.3-70B-Instruct"

model = HuggingFaceEndpoint(
    task='conversational',
    repo_id=repo_id,
    temperature=0.5,
    huggingfacehub_api_token=os.getenv('HF_TOKEN'),
    max_new_tokens=1500,
)

chat_llm = ChatHuggingFace(llm=model, verbose=True)

# Sentence-embedding model backing both Chroma vector stores.
# NOTE(review): "cuda" assumes the Space has a GPU — confirm hardware tier.
model_name = "sentence-transformers/all-mpnet-base-v2"
model_kwargs = {"device": "cuda"}

embeddings = HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs)
44
+
45
def docs_return(directory_path, flag):
    """Load every .docx, .pdf, and .txt file found directly under *directory_path*.

    Args:
        directory_path: directory to scan; must end with a path separator
            because patterns are joined by string concatenation.
        flag: 0 -> return the list of loaded Document objects;
              anything else -> return one string joining the first
              page/section text of each file with blank lines.

    Returns:
        list of Documents, or a single joined string (see *flag*).
    """
    docx_file_paths = glob.glob(directory_path + '*.docx')
    pdf_file_paths = glob.glob(directory_path + '*.pdf')
    txt_file_paths = glob.glob(directory_path + '*.txt')

    all_doc, all_doc2 = [], []

    for path in docx_file_paths:
        documents = Docx2txtLoader(path).load()
        all_doc.extend(documents)
        if documents:  # guard: an empty file would make documents[0] raise IndexError
            all_doc2.append(str(documents[0].page_content))

    for path in pdf_file_paths:
        # lazy_load keeps memory flat for large PDFs; extract_images OCRs embedded images.
        documents = list(PyPDFLoader(path, extract_images=True).lazy_load())
        all_doc.extend(documents)
        if documents:
            # NOTE: only the first page's text goes into the joined string.
            all_doc2.append(str(documents[0].page_content))

    for path in txt_file_paths:
        documents = TextLoader(path).load()
        all_doc.extend(documents)
        if documents:
            all_doc2.append(str(documents[0].page_content))

    docs = '\n\n'.join(all_doc2)

    return all_doc if flag == 0 else docs
80
+
81
def get_text_splitter(splitter_type='character',
                      chunk_size=500,
                      chunk_overlap=30,
                      separator="\n",
                      max_tokens=1000):
    """Build a LangChain text splitter of the requested flavour.

    Args:
        splitter_type: one of 'character', 'recursive', or 'token'.
        chunk_size: target chunk size for the character-based splitters.
        chunk_overlap: overlap between consecutive chunks.
        separator: split boundary used by the plain character splitter.
        max_tokens: chunk size (in tokens) for the token splitter.

    Raises:
        ValueError: if *splitter_type* is not a recognised kind.
    """
    # Dispatch table of lazy factories: nothing is constructed until the
    # chosen entry is called.
    factories = {
        'character': lambda: CharacterTextSplitter(chunk_size=chunk_size,
                                                   chunk_overlap=chunk_overlap,
                                                   separator=separator),
        'recursive': lambda: RecursiveCharacterTextSplitter(chunk_size=chunk_size,
                                                            chunk_overlap=chunk_overlap),
        'token': lambda: TokenTextSplitter(chunk_size=max_tokens,
                                           chunk_overlap=chunk_overlap),
    }
    try:
        return factories[splitter_type]()
    except KeyError:
        raise ValueError("Unsupported splitter type. Choose from 'character', 'recursive', or 'token'.") from None
94
+
95
# Splitter configuration for the document corpus.
splitter_type = 'character'
chunk_size = 1500
chunk_overlap = 30
separator = "\n"
max_tokens = 1000
docs_path = "./docs/"

all_doc = docs_return(docs_path, 0)

# Build the configured splitter and chunk the documents.
text_splitter = get_text_splitter(splitter_type=splitter_type,
                                  chunk_size=chunk_size,
                                  chunk_overlap=chunk_overlap,
                                  separator=separator,
                                  max_tokens=max_tokens)

docs = text_splitter.split_documents(documents=all_doc)

# Index the chunks in a persistent Chroma vector store.
docs_vector_db = Chroma.from_documents(docs, embeddings, persist_directory="chroma_data")

# Fix: DATA_DIR was referenced but never defined (NameError). The reference
# CSV ships in the CGIAR/weai-ref snapshot downloaded to ./refs above —
# TODO confirm the file actually lives at the snapshot root.
DATA_DIR = "./refs"
REFS_CSV_PATH = f"{DATA_DIR}/WEAI reference list - Sheet1.csv"
REFS_CHROMA_PATH = "./refs/"

loader = CSVLoader(file_path=REFS_CSV_PATH,
                   source_column="Description (what it contains and what it's useful for)")
refs = loader.load()

refs_vector_db = Chroma.from_documents(
    refs, embeddings, persist_directory=REFS_CHROMA_PATH
)
127
+
128
@dynamic_prompt
def ref_context(request: ModelRequest) -> str:
    """Build the contact agent's system prompt with retrieved reference context.

    Fix: the original f-string interpolated the *retriever object* itself,
    so the LLM saw its repr() instead of any reference content. We now run
    the retrieval against the latest user message and join the document
    texts. `k` also belongs in `search_kwargs`, not as a direct kwarg.
    """
    last_query = request.state["messages"][-1].text
    retriever = refs_vector_db.as_retriever(search_kwargs={"k": 10})
    ref_docs = retriever.invoke(last_query)
    ref_content = "\n\n".join(doc.page_content for doc in ref_docs)

    system_message = (
        """Your job is to use relevant links and email addresses to
direct users to in order to reach and contact the WEAI team. If you don't know
an answer, say you don't know. Do not state that you are referring to the
provided context and respond as if you were in charge of the WEAI helpdesk."""
        f"\n\n{ref_content}"
    )

    return system_message
143
+
144
# Agent that answers with WEAI contact/referral information.
# Fix: the chat model defined above is `chat_llm`; `chat` was a NameError.
contact_agent = create_agent(chat_llm, tools=[], middleware=[ref_context])

@tool("contact", description="refer users to WEAI team using links and contact details")
def call_contact_agent(query: str):
    """Run the contact agent on *query* and return its final message text."""
    result = contact_agent.invoke({"messages": [{"role": "user", "content": query}]})
    return result["messages"][-1].content
150
+
151
+
152
@dynamic_prompt
def doc_context(request: ModelRequest) -> str:
    """Build the support agent's system prompt with retrieved document context.

    Fix: as in ref_context, the original interpolated the retriever object
    (its repr) rather than retrieved text; retrieval now runs against the
    latest user message, and `k` is passed via `search_kwargs`.
    """
    last_query = request.state["messages"][-1].text
    retriever = docs_vector_db.as_retriever(search_kwargs={"k": 10})
    doc_docs = retriever.invoke(last_query)
    doc_content = "\n\n".join(doc.page_content for doc in doc_docs)

    system_message = (
        """Your job is to use resources from the International Food
Policy Research Institute to answer questions about women empowerment in agriculture.
Use the following context to answer questions. Be as detailed
as possible, but don't make up any information that's not
from the context and where possible reference related studies and resources as examples
from the context you have. If you don't know an answer, say you don't know.
Be concise but thorough in your response and try not to exceed the output token limit.
Do not state that you are referring to the provided context and respond
as if you were in charge of the WEAI helpdesk. """
        f"\n\n{doc_content}"
    )

    return system_message
172
+
173
# RAG support agent over the WEAI document corpus, with the contact tool
# available. Fix: model variable is `chat_llm`; `chat` was a NameError.
support_agent = create_agent(chat_llm, tools=[call_contact_agent], middleware=[doc_context])
174
 
 
 
 
 
 
 
 
 
 
 
 
175
 
176
@tool("support", description="respond to user queries using context in WEAI docs")
def call_support_agent(query: str):
    """Invoke the document-grounded support agent and return its final reply text."""
    outcome = support_agent.invoke({"messages": [{"role": "user", "content": query}]})
    final_message = outcome["messages"][-1]
    return final_message.content
180
+
181
# Orchestrator prompt: answer via the support tool first, then backfill
# contact/resource details via the contact tool when missing.
support_instructions = """
You are in charge of the WEAI helpdesk.
Your job is to answer user queries using provided context and references
and refer users to WEAI personnel as well as relevant resource links where necessary.

Steps:
1. Use the support tool to answer queries to the best of your knowledge.
2. If no contact information or links are provided in the response, use the
contact tool to add all relevant contact and resource information to the response.
3. Return only a complete response with included contact and resource information.
"""

# Top-level agent routing between the support and contact tools.
# Fix: the model object is `chat_llm`; `chat` was never defined.
response_agent = create_agent(model=chat_llm,
                              tools=[call_contact_agent, call_support_agent],
                              system_prompt=support_instructions,
                              )
197
  """
198
  For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
199
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
with gr.Blocks() as demo:
    with gr.Sidebar():
        gr.LoginButton()
    gr.Markdown("# WEAI-bot")
    chatbot = gr.Chatbot(type='messages',
                         allow_tags=True)
    msg = gr.Textbox()
    clear = gr.ClearButton([msg, chatbot])

    def support_agent_fn(message, history):
        """Answer *message* with the support agent and append the turn to history.

        Fix: this feeds outputs [msg, chatbot]; returning the response as the
        first value dumped the bot's reply into the input textbox. Return ""
        so the textbox is cleared after submit instead.
        """
        result = support_agent.invoke({"messages": [{"role": "user", "content": message}]})

        response = result['messages'][-1].content
        history.append({"role": "user", "content": message})
        history.append({"role": "assistant", "content": response})

        return "", history

    def handle_undo(history, undo_data: gr.UndoData):
        # Drop the undone turn and put the user's text back into the textbox.
        # NOTE(review): assumes message content is [{"text": ...}] — confirm
        # against the chatbot's actual message format.
        return history[:undo_data.index], history[undo_data.index]['content'][0]["text"]

    def handle_retry(history, retry_data: gr.RetryData):
        # Fix: this previously called `respond`, a function removed in this
        # commit, so every retry raised NameError. Re-run the support agent
        # on the retried prompt instead.
        previous_prompt = history[retry_data.index]['content'][0]["text"]
        _, new_history = support_agent_fn(previous_prompt, history[:retry_data.index])
        return new_history

    def handle_like(data: gr.LikeData):
        # Feedback is only logged server-side for now.
        if data.liked:
            print("You upvoted this response: ", data.value)
        else:
            print("You downvoted this response: ", data.value)

    def handle_edit(history, edit_data: gr.EditData):
        # NOTE(review): truncating at edit_data.index then editing [-1]
        # rewrites the message *before* the edited one — verify the intended
        # EditData semantics.
        new_history = history[:edit_data.index]
        new_history[-1]['content'] = [{"text": edit_data.value, "type": "text"}]
        return new_history

    msg.submit(support_agent_fn, [msg, chatbot], [msg, chatbot])

    chatbot.undo(handle_undo, chatbot, [chatbot, msg])
    chatbot.retry(handle_retry, chatbot, chatbot)
    chatbot.like(handle_like, None, None)
    chatbot.edit(handle_edit, chatbot, chatbot)

if __name__ == "__main__":
    demo.launch()