salomonsky committed
Commit 02f697f · verified · 1 Parent(s): 8f514e3

Update app.py

Files changed (1)
  1. app.py +12 -8
app.py CHANGED
@@ -1,6 +1,9 @@
+import concurrent.futures
 import gradio as gr
+from dogpile.cache import make_region
 from huggingface_hub import InferenceClient
-import concurrent.futures
+
+cache = make_region().configure('dogpile.cache.memory', thread_local=True)
 
 system_prompt = ""
 system_prompt_sent = False
@@ -36,6 +39,10 @@ def generate(prompt, history, temperature=0.9, max_new_tokens=4096, top_p=0.95,
     )
 
     formatted_prompt = format_prompt(prompt, history)
+    cache_key = f"generate:{formatted_prompt}:{temperature}:{max_new_tokens}:{top_p}:{repetition_penalty}"
+    cached_response = cache.get(cache_key)
+    if cached_response is not None:
+        return cached_response
 
     stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=True)
     output = ""
@@ -44,17 +51,14 @@ def generate(prompt, history, temperature=0.9, max_new_tokens=4096, top_p=0.95,
         output += response.token.text
         yield output
 
+    cache.set(cache_key, output)
+
     return output
 
-def run_chatbot(prompt, history, temperature, max_new_tokens, top_p, repetition_penalty):
-    global client
-    client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
-    with concurrent.futures.ProcessPoolExecutor() as executor:
-        result = executor.submit(generate, prompt, history, temperature, max_new_tokens, top_p, repetition_penalty)
-        return result
+client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
 
 chat_interface = gr.ChatInterface(
-    fn=run_chatbot,
+    fn=generate,
     chatbot=gr.Chatbot(show_label=False, show_share_button=False, show_copy_button=False, likeable=False, layout="vertical", height=900),
     concurrency_limit=9,
     theme="soft",
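
The change drops the run_chatbot wrapper (which submitted generate to a concurrent.futures.ProcessPoolExecutor) in favour of a module-level InferenceClient plus an in-process dogpile.cache memory region, keyed on the formatted prompt and the sampling parameters. A minimal, self-contained sketch of that get/compute/set pattern follows; region, expensive_call and cached_generate are illustrative names, not part of app.py, and the sketch compares against dogpile's NO_VALUE sentinel, which region.get returns on a miss.

from dogpile.cache import make_region
from dogpile.cache.api import NO_VALUE

# In-memory backend: cached entries live only as long as this process.
region = make_region().configure("dogpile.cache.memory")

def expensive_call(prompt: str, temperature: float) -> str:
    # Stand-in for the real model call (client.text_generation in app.py).
    return f"response to {prompt!r} at temperature {temperature}"

def cached_generate(prompt: str, temperature: float) -> str:
    key = f"generate:{prompt}:{temperature}"
    hit = region.get(key)  # returns the NO_VALUE sentinel when the key is absent
    if hit is not NO_VALUE:
        return hit
    result = expensive_call(prompt, temperature)
    region.set(key, result)
    return result

print(cached_generate("hello", 0.9))  # computed and stored
print(cached_generate("hello", 0.9))  # served from the memory region

Because the backend is plain per-process memory, cached responses are not shared across worker processes and do not survive an app restart.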