JoPmt commited on
Commit
eedd6da
·
1 Parent(s): 4a239fa

Upload 2 files

Browse files
Files changed (2) hide show
  1. app (22).py +103 -0
  2. requirements (76).txt +2 -0
app (22).py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ from text_generation import Client, InferenceAPIClient
4
+
5
# System pre-prompt prepended to every conversation before the chat history.
# This exact text is sent to the model at inference time, so it is left
# byte-identical. NOTE(review): "7 Billion token model" presumably means
# "parameter" — confirm before editing the runtime prompt.
cntxt = (
    "\nUser: Hi!\nAssistant: I'm Mistral, a 7 Billion token model checkpoint, my version is 0.1, part of a larger set of checkpoints "
    "trained by MistralAI! I was created to be an excellent expert assistant capable of carefully, logically, truthfully, methodically fulfilling any User request."
    "I'm capable of acting as an expert AI Writing model, acting as an expert AI Programming model, acting as an expert AI Medical model, acting as an expert AI Legal model and much more... "
    "I'm programmed to be helpful, polite, honest, and friendly.\n"
)
11
+
12
def get_client(model: str):
    """Create an Inference API client for *model*.

    Authenticates with the ``HF_TOKEN`` environment variable when it is
    set; otherwise the client is created without a token.
    """
    hf_token = os.getenv("HF_TOKEN", None)
    return InferenceAPIClient(model, token=hf_token)
14
+
15
def get_usernames(model: str):
    """Return the prompt scaffolding used to serialize the conversation.

    The *model* argument is accepted for interface compatibility but does
    not affect the result.

    Returns:
        (str, str, str, str): pre-prompt, user prefix, assistant prefix,
        turn separator.
    """
    user_prefix = "User: "
    assistant_prefix = "Assistant: "
    separator = "\n"
    return cntxt, user_prefix, assistant_prefix, separator
21
+
22
def predict(model: str, inputs: str, typical_p: float, top_p: float, temperature: float, top_k: int, repetition_penalty: float, watermark: bool, chatbot, history):
    """Stream a chat completion from the inference API.

    Rebuilds the full prompt from the pre-prompt, the prior chatbot turns and
    the new user input, then yields ``(chat, history)`` as tokens stream in so
    the Gradio Chatbot updates incrementally.

    Args:
        model: Model id; only "mistralai/Mistral-7B-v0.1" is supported.
        inputs: The new user message.
        typical_p: Typical-decoding mass forwarded to ``generate_stream``.
        top_p, temperature, top_k, repetition_penalty, watermark: Accepted for
            interface compatibility with the UI wiring; not forwarded here
            (the generation call hard-codes ``watermark=False``).
        chatbot: List of (user, assistant) message pairs shown so far.
        history: Flat list of alternating user/assistant messages (mutated).

    Yields:
        (chat, history): ``chat`` is a list of (user, assistant) string pairs.

    Raises:
        ValueError: If ``model`` is not the supported checkpoint. (The
            original code left ``iterator`` unbound in that case and crashed
            with UnboundLocalError instead.)
    """
    client = get_client(model)
    preprompt, user_name, assistant_name, sep = get_usernames(model)
    history.append(inputs)

    # Re-serialize prior turns, ensuring each side carries its role prefix.
    past = []
    for user_data, model_data in chatbot:
        if not user_data.startswith(user_name):
            user_data = user_name + user_data
        if not model_data.startswith(sep + assistant_name):
            model_data = sep + assistant_name + model_data
        past.append(user_data + model_data.rstrip() + sep)

    if not inputs.startswith(user_name):
        inputs = user_name + inputs

    total_inputs = preprompt + "".join(past) + inputs + sep + assistant_name.rstrip()
    partial_words = ""

    # BUG FIX: the original membership test listed the same model id twice,
    # and any non-matching id left ``iterator`` unbound. Fail fast instead.
    if model != "mistralai/Mistral-7B-v0.1":
        raise ValueError(f"Unsupported model: {model}")
    iterator = client.generate_stream(
        total_inputs,
        typical_p=typical_p,
        truncate=1024,
        watermark=False,
        max_new_tokens=1024,
    )

    def _strip_suffix(text: str, suffix: str) -> str:
        # BUG FIX: the original used str.rstrip(suffix), which strips any
        # trailing characters drawn from ``suffix`` (a character SET), not the
        # suffix string itself — e.g. it could eat a legitimate trailing "r",
        # "e" or ":" from model output. Remove the exact suffix only.
        return text[: -len(suffix)] if suffix and text.endswith(suffix) else text

    for i, response in enumerate(iterator):
        if response.token.special:
            continue  # skip control tokens (BOS/EOS etc.)

        partial_words += response.token.text
        # Drop a trailing role tag the model may have begun to emit.
        partial_words = _strip_suffix(partial_words, user_name.rstrip())
        partial_words = _strip_suffix(partial_words, assistant_name.rstrip())

        if i == 0:
            history.append(" " + partial_words)
        elif response.token.text not in user_name:
            history[-1] = partial_words

        # Pair history entries as (user, assistant) tuples for the Chatbot.
        chat = [
            (history[j].strip(), history[j + 1].strip())
            for j in range(0, len(history) - 1, 2)
        ]
        yield chat, history
64
+
65
def reset_textbox():
    """Clear the input textbox by emitting an empty-value component update."""
    cleared = gr.update(value="")
    return cleared
67
+
68
def radio_on_change(value: str, typical_p, top_p, top_k, temperature, repetition_penalty, watermark):
    """Adjust parameter-control visibility when the model radio changes.

    BUG FIX: the original tested the module-level ``model`` (a gr.Radio
    component object) instead of the ``value`` argument, so the comparison
    against model-id strings was always False and the function was a silent
    no-op passthrough. It also listed the same model id twice in the tuple.

    Args:
        value: The newly selected model id.
        typical_p, top_p, top_k, temperature, repetition_penalty, watermark:
            The Gradio components whose visibility/values are adjusted.

    Returns:
        A 6-tuple of component updates (or the untouched components when the
        selected model is not recognized), in the same order as the inputs.
    """
    if value == "mistralai/Mistral-7B-v0.1":
        typical_p = typical_p.update(value=0.2, visible=True)
        top_p = top_p.update(visible=False)
        top_k = top_k.update(visible=False)
        temperature = temperature.update(visible=False)
        repetition_penalty = repetition_penalty.update(visible=False)
        watermark = watermark.update(value=False)
    return (typical_p, top_p, top_k, temperature, repetition_penalty, watermark)
77
+
78
# --- Gradio UI -----------------------------------------------------------
# Single-column chat layout. The model radio and the parameter accordion are
# both hidden, so the app effectively runs one fixed Mistral checkpoint.
# NOTE(review): nesting of the Accordion and the event wiring is reconstructed
# from a diff view that lost indentation — confirm against the original file.
with gr.Blocks(
    css="""#col_container {margin-left: auto; margin-right: auto;}
#chatbot {height: 520px; overflow: auto;}"""
) as demo:
    with gr.Column(elem_id="col_container"):
        # Hidden selector: only one checkpoint is offered.
        model = gr.Radio(value="mistralai/Mistral-7B-v0.1",choices=["mistralai/Mistral-7B-v0.1",],label="Model",visible=False,)
        chatbot = gr.Chatbot(elem_id="chatbot")
        inputs = gr.Textbox(placeholder="Hi there!", label="Type an input and press Enter")
        state = gr.State([])  # flat alternating user/assistant history fed to predict()
        b1 = gr.Button()

        # Sampling parameters (accordion hidden; most sliders also hidden).
        # NOTE(review): minimum=-0 evaluates to plain 0.
        with gr.Accordion("Parameters", open=False, visible=False):
            typical_p = gr.Slider(minimum=-0,maximum=1.0,value=0.2,step=0.05,interactive=True,label="Typical P mass",)
            top_p = gr.Slider(minimum=-0,maximum=1.0,value=0.25,step=0.05,interactive=True,label="Top-p (nucleus sampling)",visible=False,)
            temperature = gr.Slider(minimum=-0,maximum=5.0,value=0.6,step=0.1,interactive=True,label="Temperature",visible=False,)
            top_k = gr.Slider(minimum=1,maximum=50,value=50,step=1,interactive=True,label="Top-k",visible=False,)
            repetition_penalty = gr.Slider(minimum=0.1,maximum=3.0,value=1.03,step=0.01,interactive=True,label="Repetition Penalty",visible=False,)
            watermark = gr.Checkbox(value=False, label="Text watermarking")

    # Re-tune parameter visibility when the model selection changes.
    model.change(lambda value: radio_on_change(value,typical_p,top_p,top_k,temperature,repetition_penalty,watermark,),inputs=model,outputs=[typical_p,top_p,top_k,temperature,repetition_penalty,watermark,],)
    # Both Enter-in-textbox and the button trigger streaming generation...
    inputs.submit(predict,[model,inputs,typical_p,top_p,temperature,top_k,repetition_penalty,watermark,chatbot,state,], [chatbot, state],)
    b1.click(predict,[model,inputs,typical_p,top_p,temperature,top_k,repetition_penalty,watermark,chatbot,state,], [chatbot, state],)
    # ...and both also clear the textbox afterwards.
    b1.click(reset_textbox, [], [inputs])
    inputs.submit(reset_textbox, [], [inputs])

# Serialize requests: single queue slot and a single worker thread.
demo.queue(max_size=1).launch(max_threads=1,)
requirements (76).txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ text-generation==0.5.0
2
+ gradio==3.20.1