ProximileAdmin committed on
Commit
9164b43
·
verified ·
1 Parent(s): 0eead1b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -92
app.py CHANGED
@@ -9,18 +9,12 @@ import time
9
  from typing import Dict, List, Optional
10
 
11
  ENDPOINT_URL = "https://api.hyperbolic.xyz/v1"
12
-
13
  OAI_API_KEY = os.getenv('HYPERBOLIC_XYZ_KEY')
14
-
15
  VERBOSE_SHELL = True
16
-
17
  todays_date_string = datetime.date.today().strftime("%d %B %Y")
18
 
19
-
20
  NAME_OF_SERVICE = "arXiv Paper Search"
21
- DESCRIPTION_OF_SERVICE = (
22
- "a service that searches and retrieves academic papers from arXiv based on various criteria"
23
- )
24
  PAPER_SEARCH_FUNCTION_NAME = "search_arxiv_papers"
25
 
26
  functions_list = [
@@ -33,8 +27,8 @@ functions_list = [
33
  "type": "object",
34
  "properties": {
35
  "query": {
36
- "type": "string", # function names for AI agents should be chosen carefully to avoid confusion
37
- "description": "Search query (e.g., 'deep learning', 'quantum computing')" # descriptions help the AI agent's LLM backend understand the function
38
  },
39
  "max_results": {
40
  "type": "integer",
@@ -69,27 +63,9 @@ After receiving the results back from a function (formatted as {{"name": functio
69
 
70
  If the user request does not necessitate a function call, simply respond to the user's query directly."""
71
 
72
- def search_arxiv_papers(
73
- query: str,
74
- max_results: int = 5,
75
- sort_by: str = 'relevance'
76
- ) -> Dict:
77
- """
78
- Search for papers on arXiv using their API.
79
-
80
- Args:
81
- query: Search query string
82
- max_results: Maximum number of results to return (default: 5)
83
- sort_by: Sorting criteria (default: 'relevance')
84
-
85
- Returns:
86
- Dictionary containing search results and metadata
87
- """
88
  try:
89
- # Construct the search query
90
  search_query = f'all:{query}'
91
-
92
- # Construct the API URL
93
  base_url = 'http://export.arxiv.org/api/query?'
94
  params = {
95
  'search_query': search_query,
@@ -100,12 +76,8 @@ def search_arxiv_papers(
100
  }
101
  query_string = '&'.join([f'{k}={urllib.parse.quote(str(v))}' for k, v in params.items()])
102
  url = base_url + query_string
103
-
104
- # Make the API request
105
  response = urllib.request.urlopen(url)
106
  feed = feedparser.parse(response.read().decode('utf-8'))
107
-
108
- # Process the results
109
  papers = []
110
  for entry in feed.entries:
111
  paper = {
@@ -118,16 +90,12 @@ def search_arxiv_papers(
118
  'primary_category': entry.tags[0]['term']
119
  }
120
  papers.append(paper)
121
-
122
- # Add a delay to respect API rate limits
123
  time.sleep(3)
124
-
125
  return {
126
  'status': 'success',
127
  'total_results': len(papers),
128
  'papers': papers
129
  }
130
-
131
  except Exception as e:
132
  return {
133
  'status': 'error',
@@ -136,7 +104,6 @@ def search_arxiv_papers(
136
 
137
  functions_dict = {f["function"]["name"]: f for f in functions_list}
138
  FUNCTION_BACKENDS = {
139
- #WALLET_CHECK_FUNCTION_NAME: check_wallet_balance,
140
  PAPER_SEARCH_FUNCTION_NAME: search_arxiv_papers,
141
  }
142
 
@@ -149,8 +116,6 @@ class LLM:
149
  self.api_key = OAI_API_KEY
150
  self.max_model_len = max_model_len
151
  self.client = OpenAI(base_url=ENDPOINT_URL, api_key=self.api_key)
152
- #models_list = self.client.models.list()
153
- #self.model_name = models_list.data[0].id
154
  self.model_name = "meta-llama/Llama-3.3-70B-Instruct"
155
 
156
  def generate(self, prompt: str, sampling_params: dict) -> dict:
@@ -163,18 +128,15 @@ class LLM:
163
  "n": sampling_params.get("n", 1),
164
  "stream": False,
165
  }
166
-
167
  if "stop" in sampling_params:
168
  completion_params["stop"] = sampling_params["stop"]
169
  if "presence_penalty" in sampling_params:
170
  completion_params["presence_penalty"] = sampling_params["presence_penalty"]
171
  if "frequency_penalty" in sampling_params:
172
  completion_params["frequency_penalty"] = sampling_params["frequency_penalty"]
173
-
174
  return self.client.completions.create(**completion_params)
175
 
176
  def form_chat_prompt(message_history, functions=functions_dict.keys()):
177
- """Builds the chat prompt for the LLM."""
178
  functions_string = "\n\n".join([json.dumps(functions_dict[f], indent=4) for f in functions])
179
  full_prompt = (
180
  ROLE_HEADER.format(role="system")
@@ -193,7 +155,6 @@ def form_chat_prompt(message_history, functions=functions_dict.keys()):
193
  return full_prompt
194
 
195
  def check_assistant_response_for_tool_calls(response):
196
- """Check if the LLM response contains a function call."""
197
  response = response.split(FUNCTION_EOT_STRING)[0].split(EOT_STRING)[0]
198
  for tool_name in functions_dict.keys():
199
  if f"\"{tool_name}\"" in response and "{" in response:
@@ -207,21 +168,17 @@ def check_assistant_response_for_tool_calls(response):
207
  return None
208
 
209
  def process_tool_request(tool_request_data):
210
- """Process tool requests from the LLM."""
211
  tool_name = tool_request_data["name"]
212
  tool_parameters = tool_request_data["parameters"]
213
-
214
  if tool_name == PAPER_SEARCH_FUNCTION_NAME:
215
  query = tool_parameters["query"]
216
  max_results = tool_parameters.get("max_results", 5)
217
  sort_by = tool_parameters.get("sort_by", "relevance")
218
  search_results = FUNCTION_BACKENDS[tool_name](query, max_results, sort_by)
219
  return {"name": PAPER_SEARCH_FUNCTION_NAME, "results": search_results}
220
-
221
  return None
222
 
223
  def restore_message_history(full_history):
224
- """Restore the complete message history including tool interactions."""
225
  restored = []
226
  for message in full_history:
227
  if message["role"] == "assistant" and "metadata" in message:
@@ -239,13 +196,10 @@ def restore_message_history(full_history):
239
  return restored
240
 
241
  def iterate_chat(llm, sampling_params, full_history):
242
- """Handle conversation turns with tool calling."""
243
  tool_interactions = []
244
-
245
  for _ in range(10):
246
  prompt = form_chat_prompt(restore_message_history(full_history) + tool_interactions)
247
  output = llm.generate(prompt, sampling_params)
248
-
249
  if VERBOSE_SHELL:
250
  print(f"Input prompt: {prompt}")
251
  print("-" * 50)
@@ -253,10 +207,8 @@ def iterate_chat(llm, sampling_params, full_history):
253
  print("=" * 50)
254
  if not output or not output.choices:
255
  raise ValueError("Invalid completion response")
256
-
257
  assistant_response = output.choices[0].text.strip()
258
  assistant_response = assistant_response.split(FUNCTION_EOT_STRING)[0].split(EOT_STRING)[0]
259
-
260
  tool_request_data = check_assistant_response_for_tool_calls(assistant_response)
261
  if not tool_request_data:
262
  final_message = {
@@ -275,58 +227,41 @@ def iterate_chat(llm, sampling_params, full_history):
275
  }
276
  tool_interactions.append(assistant_message)
277
  tool_return_data = process_tool_request(tool_request_data)
278
-
279
  tool_message = {
280
  "role": "function",
281
  "content": json.dumps(tool_return_data)
282
  }
283
  tool_interactions.append(tool_message)
284
-
285
  return full_history
286
 
287
- def user_conversation(user_message, chat_history, full_history):
288
- """Handle user input and maintain conversation state."""
289
- if full_history is None:
290
- full_history = []
291
-
292
- full_history.append({"role": "user", "content": user_message})
 
 
 
 
 
293
  updated_history = iterate_chat(llm, sampling_params, full_history)
294
  assistant_answer = updated_history[-1]["content"]
295
- chat_history.append((user_message, assistant_answer))
296
-
297
- return "", chat_history, updated_history
298
-
299
- sampling_params = {
300
- "temperature": 0.8,
301
- "top_p": 0.95,
302
- "max_tokens": 512,
303
- "stop_token_ids": [128001,128008,128009,128006],
304
- }
305
 
306
  # Initialize LLM
307
  llm = LLM(max_model_len=8096)
308
 
309
- with gr.Blocks() as demo:
310
- gr.Markdown(f"<h2>{NAME_OF_SERVICE}</h2>")
311
- chat_state = gr.State([])
312
- chatbot = gr.Chatbot(label="Chat with the arXiv Paper Search Assistant")
313
- user_input = gr.Textbox(
314
- lines=1,
315
- placeholder="Type your message here...",
316
- )
317
-
318
- user_input.submit(
319
- fn=user_conversation,
320
- inputs=[user_input, chatbot, chat_state],
321
- outputs=[user_input, chatbot, chat_state],
322
- queue=False
323
- )
324
 
325
- send_button = gr.Button("Send")
326
- send_button.click(
327
- fn=user_conversation,
328
- inputs=[user_input, chatbot, chat_state],
329
- outputs=[user_input, chatbot, chat_state],
330
- queue=False
331
- )
332
  demo.launch()
 
9
  from typing import Dict, List, Optional
10
 
11
  ENDPOINT_URL = "https://api.hyperbolic.xyz/v1"
 
12
  OAI_API_KEY = os.getenv('HYPERBOLIC_XYZ_KEY')
 
13
  VERBOSE_SHELL = True
 
14
  todays_date_string = datetime.date.today().strftime("%d %B %Y")
15
 
 
16
  NAME_OF_SERVICE = "arXiv Paper Search"
17
+ DESCRIPTION_OF_SERVICE = "a service that searches and retrieves academic papers from arXiv based on various criteria"
 
 
18
  PAPER_SEARCH_FUNCTION_NAME = "search_arxiv_papers"
19
 
20
  functions_list = [
 
27
  "type": "object",
28
  "properties": {
29
  "query": {
30
+ "type": "string",
31
+ "description": "Search query (e.g., 'deep learning', 'quantum computing')"
32
  },
33
  "max_results": {
34
  "type": "integer",
 
63
 
64
  If the user request does not necessitate a function call, simply respond to the user's query directly."""
65
 
66
+ def search_arxiv_papers(query: str, max_results: int = 5, sort_by: str = 'relevance') -> Dict:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  try:
 
68
  search_query = f'all:{query}'
 
 
69
  base_url = 'http://export.arxiv.org/api/query?'
70
  params = {
71
  'search_query': search_query,
 
76
  }
77
  query_string = '&'.join([f'{k}={urllib.parse.quote(str(v))}' for k, v in params.items()])
78
  url = base_url + query_string
 
 
79
  response = urllib.request.urlopen(url)
80
  feed = feedparser.parse(response.read().decode('utf-8'))
 
 
81
  papers = []
82
  for entry in feed.entries:
83
  paper = {
 
90
  'primary_category': entry.tags[0]['term']
91
  }
92
  papers.append(paper)
 
 
93
  time.sleep(3)
 
94
  return {
95
  'status': 'success',
96
  'total_results': len(papers),
97
  'papers': papers
98
  }
 
99
  except Exception as e:
100
  return {
101
  'status': 'error',
 
104
 
105
  functions_dict = {f["function"]["name"]: f for f in functions_list}
106
  FUNCTION_BACKENDS = {
 
107
  PAPER_SEARCH_FUNCTION_NAME: search_arxiv_papers,
108
  }
109
 
 
116
  self.api_key = OAI_API_KEY
117
  self.max_model_len = max_model_len
118
  self.client = OpenAI(base_url=ENDPOINT_URL, api_key=self.api_key)
 
 
119
  self.model_name = "meta-llama/Llama-3.3-70B-Instruct"
120
 
121
  def generate(self, prompt: str, sampling_params: dict) -> dict:
 
128
  "n": sampling_params.get("n", 1),
129
  "stream": False,
130
  }
 
131
  if "stop" in sampling_params:
132
  completion_params["stop"] = sampling_params["stop"]
133
  if "presence_penalty" in sampling_params:
134
  completion_params["presence_penalty"] = sampling_params["presence_penalty"]
135
  if "frequency_penalty" in sampling_params:
136
  completion_params["frequency_penalty"] = sampling_params["frequency_penalty"]
 
137
  return self.client.completions.create(**completion_params)
138
 
139
  def form_chat_prompt(message_history, functions=functions_dict.keys()):
 
140
  functions_string = "\n\n".join([json.dumps(functions_dict[f], indent=4) for f in functions])
141
  full_prompt = (
142
  ROLE_HEADER.format(role="system")
 
155
  return full_prompt
156
 
157
  def check_assistant_response_for_tool_calls(response):
 
158
  response = response.split(FUNCTION_EOT_STRING)[0].split(EOT_STRING)[0]
159
  for tool_name in functions_dict.keys():
160
  if f"\"{tool_name}\"" in response and "{" in response:
 
168
  return None
169
 
170
  def process_tool_request(tool_request_data):
 
171
  tool_name = tool_request_data["name"]
172
  tool_parameters = tool_request_data["parameters"]
 
173
  if tool_name == PAPER_SEARCH_FUNCTION_NAME:
174
  query = tool_parameters["query"]
175
  max_results = tool_parameters.get("max_results", 5)
176
  sort_by = tool_parameters.get("sort_by", "relevance")
177
  search_results = FUNCTION_BACKENDS[tool_name](query, max_results, sort_by)
178
  return {"name": PAPER_SEARCH_FUNCTION_NAME, "results": search_results}
 
179
  return None
180
 
181
  def restore_message_history(full_history):
 
182
  restored = []
183
  for message in full_history:
184
  if message["role"] == "assistant" and "metadata" in message:
 
196
  return restored
197
 
198
  def iterate_chat(llm, sampling_params, full_history):
 
199
  tool_interactions = []
 
200
  for _ in range(10):
201
  prompt = form_chat_prompt(restore_message_history(full_history) + tool_interactions)
202
  output = llm.generate(prompt, sampling_params)
 
203
  if VERBOSE_SHELL:
204
  print(f"Input prompt: {prompt}")
205
  print("-" * 50)
 
207
  print("=" * 50)
208
  if not output or not output.choices:
209
  raise ValueError("Invalid completion response")
 
210
  assistant_response = output.choices[0].text.strip()
211
  assistant_response = assistant_response.split(FUNCTION_EOT_STRING)[0].split(EOT_STRING)[0]
 
212
  tool_request_data = check_assistant_response_for_tool_calls(assistant_response)
213
  if not tool_request_data:
214
  final_message = {
 
227
  }
228
  tool_interactions.append(assistant_message)
229
  tool_return_data = process_tool_request(tool_request_data)
 
230
  tool_message = {
231
  "role": "function",
232
  "content": json.dumps(tool_return_data)
233
  }
234
  tool_interactions.append(tool_message)
 
235
  return full_history
236
 
237
+ def respond(message, chat_history, system_message, max_tokens, temperature, top_p):
238
+ if chat_history is None:
239
+ chat_history = []
240
+ full_history = chat_history.copy()
241
+ full_history.append({"role": "user", "content": message})
242
+ sampling_params = {
243
+ "temperature": temperature,
244
+ "top_p": top_p,
245
+ "max_tokens": max_tokens,
246
+ "stop_token_ids": [128001, 128008, 128009, 128006],
247
+ }
248
  updated_history = iterate_chat(llm, sampling_params, full_history)
249
  assistant_answer = updated_history[-1]["content"]
250
+ chat_history.append((message, assistant_answer))
251
+ return chat_history
 
 
 
 
 
 
 
 
252
 
253
  # Initialize LLM
254
  llm = LLM(max_model_len=8096)
255
 
256
+ demo = gr.ChatInterface(
257
+ respond,
258
+ additional_inputs=[
259
+ gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
260
+ gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
261
+ gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
262
+ gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
263
+ ],
264
+ )
 
 
 
 
 
 
265
 
266
+ if __name__ == "__main__":
 
 
 
 
 
 
267
  demo.launch()