lamhieu committed
Commit bf7aef8 · Parent(s): b54a72a

chore: support tools with search on internet

Files changed (2):
  1. app.py +241 -49
  2. requirements.txt +3 -1
app.py CHANGED
@@ -1,6 +1,8 @@
 # pylint: skip-file
 
 import subprocess
+import json
+import requests
 
 subprocess.run(
     f"pip install flash-attn --no-build-isolation",
@@ -15,7 +17,11 @@ from typing import Iterator
 import gradio as gr
 import spaces
 import torch
+import wikipedia
+import time
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+from bs4 import BeautifulSoup
+from functools import lru_cache
 
 
 MAX_MAX_NEW_TOKENS = 4096
@@ -25,13 +31,12 @@ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "8192"))
 DESCRIPTION = """\
 # Playground with Ghost 8B Beta (β, 8k)
 
-**Ghost 8B Beta** is a large language model developed with goals that include excellent multilingual support, superior knowledge capabilities, and cost-effectiveness. The model comes in two context length versions, [8k](https://huggingface.co/spaces/lamhieu/ghost-8b-beta-8k) and [128k](https://huggingface.co/spaces/lamhieu/ghost-8b-beta-128k), along with multilingual function tools support by default.
-
-The Ghost 8B Beta model outperforms prominent models such as Llama 3 8B Instruct, GPT 3.5 Turbo in the lc_winrate score. In addition, it also outperforms Claude 3 Opus, Claude 3 Sonnet, GPT-4, and Mistral Large when comparing the winrate score of AlpacaEval 2.0, [*](https://ghost-x.org/docs/models/ghost-8b-beta/).
+**Ghost 8B Beta** model outperforms prominent models such as Llama 3 8B Instruct, GPT 3.5 Turbo in the lc_winrate score. In addition, it also outperforms Claude 3 Opus, Claude 3 Sonnet, GPT-4, and Mistral Large when comparing the winrate score of AlpacaEval 2.0, [*](https://ghost-x.org/docs/models/ghost-8b-beta/). The model comes in two context length versions, [8k](https://huggingface.co/spaces/lamhieu/ghost-8b-beta-8k) and [128k](https://huggingface.co/spaces/lamhieu/ghost-8b-beta-128k), along with multilingual function tools support by default.
 
 The languages supported are 🇺🇸 English, 🇫🇷 French, 🇮🇹 Italian, 🇪🇸 Spanish, 🇵🇹 Portuguese, 🇩🇪 German, 🇻🇳 Vietnamese, 🇰🇷 Korean and 🇨🇳 Chinese.
 
-📋 Note: current model version is "disl-0x5" (10 Jul 2024), context length 8k (8192 tokens) and current status is "moderating / previewing". For detailed information about the model, see [here](https://ghost-x.org/docs/models/ghost-8b-beta/). Try to experience it the way you want!
+🗞️ **Updates**
+* Jul 23, 2024: added support for tools, now available to search for information on the internet.
 """
 
 
@@ -250,18 +255,116 @@ if not torch.cuda.is_available():
 
 if torch.cuda.is_available():
     model_id = "ghost-x/ghost-8b-beta"
-    model_tk = os.getenv("HF_TOKEN", None)
+    hf_serect = os.getenv("HF_TOKEN", None)
     model = AutoModelForCausalLM.from_pretrained(
         model_id,
         device_map="auto",
         torch_dtype=torch.bfloat16,
         attn_implementation="flash_attention_2",
         trust_remote_code=True,
-        token=model_tk,
+        token=hf_serect,
     )
     tokenizer = AutoTokenizer.from_pretrained(
         model_id,
        trust_remote_code=True,
-        token=model_tk,
+        token=hf_serect,
     )
 
+waiting_tools_timeout = 5
+supported_tools = json.dumps(
+    [
+        {
+            "type": "function",
+            "function": {
+                "name": "search_on_internet",
+                "description": "Use this tool to search online, only use it for information you don't know or are unsure of, don't abuse it.",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "keyword": {
+                            "type": "string",
+                            "description": "Search keywords, rephrase to optimize search results based on questions suitable to the specified search type.",
+                            "required": True,
+                        },
+                        "type": {
+                            "type": "string",
+                            "description": "Search type, based on the question to determine whether to search for it in 'wikipedia' or 'google', prefer to use wikipedia for information about events, history and people.",
+                            "enum": ["wikipedia", "google"],
+                            "default": "google",
+                            "required": True,
+                        },
+                    },
+                },
+            },
+        }
+    ],
+    ensure_ascii=False,
+)
+
+
+@lru_cache(maxsize=128)
+def extract_text_from_webpage(html_content):
+    soup = BeautifulSoup(html_content, "html.parser")
+    for tag in soup(["script", "style", "header", "footer", "nav", "form", "svg"]):
+        tag.extract()
+    visible_text = soup.get_text(strip=True, separator=" ")
+    return visible_text
+
+
+def search_with_wikipedia(query: str):
+    all_results = []
+    try:
+        all_results.append(wikipedia.summary(query))
+    except Exception as e:
+        pass
+    return all_results
+
+
+def search_with_google(
+    query: str,
+    num_results: int = 3,
+    timeout: int = 5,
+    ssl_verify: bool = None,
+):
+    all_results = []
+    max_chars_per_page = 4096
+    with requests.Session() as session:
+        resp = session.get(
+            url="https://www.google.com/search",
+            headers={
+                "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/111.0"
+            },
+            params={
+                "q": query,
+                "num": num_results,
+                "udm": 14,
+            },
+            timeout=timeout,
+            verify=ssl_verify,
+        )
+        resp.raise_for_status()
+        soup = BeautifulSoup(resp.text, "html.parser")
+        result_block = soup.find_all("div", attrs={"class": "g"})
+        for result in result_block:
+            link = result.find("a", href=True)
+            if link:
+                link = link["href"]
+                try:
+                    webpage = session.get(
+                        link,
+                        headers={
+                            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/111.0"
+                        },
+                    )
+                    webpage.raise_for_status()
+                    visible_text = extract_text_from_webpage(webpage.text)
+                    if len(visible_text) > max_chars_per_page:
+                        visible_text = visible_text[:max_chars_per_page]
+                    all_results.append({"link": link, "text": visible_text})
+                except requests.exceptions.RequestException as e:
+                    print(f"Error fetching or processing {link}: {e}")
+                    pass
+            else:
+                pass
+    return all_results
+
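The hunk above introduces the whole retrieval layer in one place: `supported_tools` is a JSON schema the model later sees under a dedicated "tools" role, `extract_text_from_webpage` strips scraped HTML down to visible text (memoized via `lru_cache`), and the two `search_with_*` helpers return plain-text references. A minimal sketch of how they compose outside the Space; the tool-call literal is an assumed example, though its keys ("type", "name", "arguments") mirror exactly what the commit's parser in `generate_chat_responses` reads:

```python
# Sketch, not part of the commit: exercising the search helpers directly,
# assuming app.py's functions are importable and network access is available.
import json

# Assumed example of the JSON a tool call from the model might look like.
tool_call = json.loads(
    '{"type": "function", "name": "search_on_internet",'
    ' "arguments": {"keyword": "Ghost 8B Beta", "type": "wikipedia"}}'
)

if tool_call["type"] == "function" and tool_call["name"] == "search_on_internet":
    args = tool_call["arguments"]
    if args["type"] == "wikipedia":
        references = search_with_wikipedia(args["keyword"])  # list of summary strings
    else:
        references = search_with_google(args["keyword"])  # list of {"link", "text"} dicts
    # In the app, these references are re-injected into the prompt under a "refs" role.
    print(json.dumps(references, ensure_ascii=False)[:300])
```

Note that the `lru_cache` on `extract_text_from_webpage` keys on the entire HTML string, so repeated fetches of an identical page skip re-parsing, at the cost of holding those strings in memory.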
@@ -268,60 +371,141 @@
 
 @spaces.GPU(duration=120)
 def generate(
     message: str,
     chat_history: list[tuple[str, str]],
-    system_prompt: str,
+    allow_used_tools: bool = True,
+    system_prompt: str = "",
     max_new_tokens: int = 1536,
     temperature: float = 0.4,
     top_p: float = 0.95,
     top_k: int = 50,
     repetition_penalty: float = 1.0,
 ) -> Iterator[str]:
-    conversation = []
-    if system_prompt:
-        conversation.append({"role": "system", "content": system_prompt})
-    for user, assistant in chat_history:
-        conversation.extend(
-            [
-                {"role": "user", "content": user},
-                {"role": "assistant", "content": assistant},
-            ]
-        )
-    conversation.append({"role": "user", "content": message})
-
-    input_ids = tokenizer.apply_chat_template(
-        conversation, add_generation_prompt=True, return_tensors="pt"
-    )
-    input_ids = input_ids.to(model.device)
-    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
-        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
-        gr.Warning(
-            f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens."
-        )
-
-    streamer = TextIteratorStreamer(
-        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
-    )
-    generate_kwargs = dict(
-        input_ids=input_ids,
-        streamer=streamer,
-        max_new_tokens=max_new_tokens,
-        do_sample=True,
-        repetition_penalty=repetition_penalty,
-    )
-    if temperature == 0:
-        generate_kwargs["do_sample"] = False
-    else:
-        generate_kwargs["temperature"] = temperature
-        generate_kwargs["top_p"] = top_p
-        generate_kwargs["top_k"] = top_k
-
-    t = Thread(target=model.generate, kwargs=generate_kwargs)
-    t.start()
-
-    outputs = []
-    for text in streamer:
-        outputs.append(text)
-        yield "".join(outputs)
+    # print()
+    # print("allow_used_tools:\n", allow_used_tools)
+    # print("system_prompt:\n", system_prompt)
+    # print("max_new_tokens:\n", max_new_tokens)
+    # print("temperature:\n", temperature)
+
+    def build_input_ids(
+        apply_tools: bool = None,
+        references: list[str] = None,
+    ):
+        conversation = []
+        if system_prompt:
+            conversation.append({"role": "system", "content": system_prompt})
+        if apply_tools is True:
+            conversation.append({"role": "tools", "content": supported_tools})
+        if (
+            references is not None
+            and isinstance(references, list)
+            and len(references) > 0
+        ):
+            conversation.append(
+                {
+                    "role": "refs",
+                    "content": json.dumps(references, ensure_ascii=False),
+                }
+            )
+
+        for user, assistant in chat_history:
+            conversation.extend(
+                [
+                    {"role": "user", "content": user},
+                    {"role": "assistant", "content": assistant},
+                ]
+            )
+        conversation.append({"role": "user", "content": message})
+
+        input_ids = tokenizer.apply_chat_template(
+            conversation, add_generation_prompt=True, return_tensors="pt"
+        )
+        input_ids = input_ids.to(model.device)
+        if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+            input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+            gr.Warning(
+                f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens."
+            )
+        return input_ids
+
+    def generate_chat_responses(
+        previous_response: str = None,
+    ):
+        document_references = []
+        if previous_response is not None:
+            scheduled_tools_runs = None
+            try:
+                scheduled_tools_runs = json.loads(previous_response)
+                if scheduled_tools_runs["type"] == "function" and scheduled_tools_runs[
+                    "name"
+                ] in ["search_on_internet"]:
+                    pass
+                else:
+                    scheduled_tools_runs = None
+            except Exception as e:
+                print(e)
+                pass
+
+            if (
+                scheduled_tools_runs is not None
+                and scheduled_tools_runs["name"] == "search_on_internet"
+            ):
+                keyword = scheduled_tools_runs["arguments"]["keyword"]
+                search_type = scheduled_tools_runs["arguments"]["type"]
+                if search_type == "wikipedia":
+                    gr.Info("Searching for information on the Wikipedia.")
+                    document_references = search_with_wikipedia(keyword)
+                else:
+                    gr.Info("Searching for information on the Google.")
+                    document_references = search_with_google(keyword)
+
+        input_ids = build_input_ids(
+            apply_tools=(
+                True
+                if allow_used_tools is True and previous_response is None
+                else False
+            ),
+            references=document_references,
+        )
+        streamer = TextIteratorStreamer(
+            tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
+        )
+        generate_kwargs = dict(
+            input_ids=input_ids,
+            streamer=streamer,
+            max_new_tokens=max_new_tokens,
+            do_sample=True,
+            repetition_penalty=repetition_penalty,
+        )
+        if temperature == 0:
+            generate_kwargs["do_sample"] = False
+        else:
+            generate_kwargs["temperature"] = temperature
+            generate_kwargs["top_p"] = top_p
+            generate_kwargs["top_k"] = top_k
+
+        t = Thread(target=model.generate, kwargs=generate_kwargs)
+        t.start()
+
+        state = {
+            "mark": None,
+            "respond": False,
+        }
+        outputs = []
+        for text in streamer:
+            if state["mark"] is None:
+                state["mark"] = time.time()
+            outputs.append(text)
+            if state["mark"] + waiting_tools_timeout < time.time():
+                state["respond"] = True
+                yield "".join(outputs)
+
+        if (
+            state["respond"] is False
+            and state["mark"] + waiting_tools_timeout > time.time()
+        ):
+            gr.Info("Searching for information on the internet.")
+            previous_response = "".join(outputs)
+            yield from generate_chat_responses(previous_response=previous_response)
+
+    yield from generate_chat_responses(previous_response=None)
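The rewritten `generate` runs up to two passes. Pass one includes the tools schema, and whether its output is treated as a tool call is decided by timing: tokens are buffered, and the stream only starts yielding to the UI once generation has run more than `waiting_tools_timeout` seconds past the first token. A stream that finishes inside that window is assumed to be a short JSON tool call and is fed back into a second, tool-free pass together with the fetched references. A condensed sketch of that heuristic (the function and parameter names here are mine, not the commit's; the logic mirrors the hunk):

```python
import time

def classify_stream(token_iter, on_delta, waiting_tools_timeout=5.0):
    """Buffer a token stream; decide if it was an answer or a candidate tool call."""
    first_token_at = None
    streamed = False
    outputs = []
    for text in token_iter:
        if first_token_at is None:
            first_token_at = time.time()
        outputs.append(text)
        if time.time() > first_token_at + waiting_tools_timeout:
            streamed = True
            on_delta("".join(outputs))  # long generations stream to the UI live
    # Guard against an empty stream; the committed code assumes at least one token.
    if streamed or first_token_at is None:
        return "answer", "".join(outputs)
    return "tool_call", "".join(outputs)

# A list iterator finishes instantly, so it lands inside the window:
kind, text = classify_stream(iter(['{"type": "function", ...}']), print)
assert kind == "tool_call"
```

One side effect worth knowing: a legitimately short answer also finishes inside the window, so it is routed through the second pass (where `json.loads` fails and the tool run is skipped) and regenerated without the tools prompt rather than shown directly.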
@@ -326,12 +512,17 @@
 
 
-chatbot = gr.Chatbot(height=500, placeholder=PLACEHOLDER, label="Ghost 8B Beta")
+chatbot = gr.Chatbot(
+    height=500, placeholder=PLACEHOLDER, label="Ghost 8B Beta", show_copy_button=True
+)
 
 chat_interface = gr.ChatInterface(
     fn=generate,
     chatbot=chatbot,
     fill_height=True,
     additional_inputs=[
+        gr.Checkbox(
+            label="Allow used tools (available: search on internet)", value=True
+        ),
         gr.Textbox(label="System prompt", lines=6),
         gr.Slider(
             label="Max new tokens",
@@ -373,6 +564,7 @@ chat_interface = gr.ChatInterface(
     cache_examples=False,
     examples=EXAMPLES,
     examples_per_page=9,
+    concurrency_limit=100,
 )
 
 with gr.Blocks(fill_height=True, css="style.css") as demo:
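On the UI side, `ChatInterface` passes `additional_inputs` to `fn` positionally after `(message, history)`, which is why the new `gr.Checkbox` must sit first in the list: it lands on `generate`'s new `allow_used_tools` parameter, ahead of the system-prompt textbox. A stripped-down sketch of that wiring (gradio 4.x semantics assumed; the stub `fn` is illustrative, not the commit's):

```python
import gradio as gr

def generate_stub(message, chat_history, allow_used_tools=True, system_prompt=""):
    # Placeholder for the real generate(); shows which widget feeds which argument.
    yield f"allow_used_tools={allow_used_tools!r}, system_prompt={system_prompt!r}"

demo = gr.ChatInterface(
    fn=generate_stub,
    additional_inputs=[
        # Order matters: checkbox -> allow_used_tools, textbox -> system_prompt.
        gr.Checkbox(label="Allow used tools (available: search on internet)", value=True),
        gr.Textbox(label="System prompt", lines=6),
    ],
    concurrency_limit=100,  # same knob the commit adds to serve more sessions
)
# demo.launch()
```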
requirements.txt CHANGED
@@ -1,8 +1,10 @@
 accelerate==0.30.1
 bitsandbytes==0.43.1
-gradio==4.37.2
+gradio==4.39.0
 scipy==1.13.0
 sentencepiece==0.2.0
 spaces==0.28.3
 torch==2.0.0
 transformers==4.41.0
+beautifulsoup4>=4.9
+wikipedia==1.4.0
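The dependency changes match the new code paths: `beautifulsoup4` backs `extract_text_from_webpage`, the `wikipedia` package backs `search_with_wikipedia`, and the `gradio` pin moves from 4.37.2 to 4.39.0. A quick import smoke test for the two additions (assumes the pinned versions were installed with pip and network access for the Wikipedia call):

```python
import wikipedia  # wikipedia==1.4.0, a thin wrapper over the MediaWiki API
from bs4 import BeautifulSoup  # beautifulsoup4>=4.9, used to strip scraped HTML

print(wikipedia.summary("Python (programming language)", sentences=1))
print(BeautifulSoup("<p>ok</p>", "html.parser").get_text())  # -> "ok"
```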