Ctaake commited on
Commit
194241b
·
verified ·
1 Parent(s): 798d29f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +143 -0
app.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Standard library
import random

# Third-party
import gradio as gr
from huggingface_hub import InferenceClient
from transformers import AutoTokenizer

# Model repository used for both prompt formatting and remote inference.
checkpoint = "mistralai/Mistral-7B-Instruct-v0.2"
# Tokenizer renders chat messages into this model's prompt template.
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# Inference client pointed at the model (pass an HF token here if needed).
client = InferenceClient(checkpoint)
def format_prompt(message, history, systemPrompt):
    """Render the chat state into the model's raw prompt string.

    The system prompt is injected as a leading user turn followed by an
    empty assistant reply (the Mistral chat template has no system role),
    then the prior history and the current message are appended, and the
    tokenizer's chat template produces the final untokenized string.
    """
    conversation = [
        {"role": "user", "content": systemPrompt},
        {"role": "assistant", "content": ""},
    ]
    # Replay every completed turn from the history.
    for user_turn, bot_turn in history:
        conversation.append({"role": "user", "content": user_turn})
        conversation.append({"role": "assistant", "content": bot_turn})
    # Finish with the message currently being answered.
    conversation.append({"role": "user", "content": message})
    # tokenize=False: we want the template text, not token ids.
    return tokenizer.apply_chat_template(conversation, tokenize=False)
def inference(message, history, systemPrompt, temperature, maxTokens, topP, repPenalty):
    """Stream a model response for the given message and chat history.

    Yields the accumulated partial response after every received token so
    the Gradio chat UI can render the answer incrementally.

    Parameters mirror the additional inputs of the ChatInterface:
    systemPrompt, temperature, maxTokens, topP, repPenalty.
    """
    # Generation settings; a fresh random seed keeps retries from
    # reproducing the exact same answer.
    client_settings = dict(
        temperature=temperature,
        max_new_tokens=maxTokens,
        top_p=topP,
        repetition_penalty=repPenalty,
        do_sample=True,
        stream=True,
        details=True,
        return_full_text=False,
        seed=random.randint(0, 999999999),
    )
    # Generate by passing the prompt in the model's format plus the settings.
    stream = client.text_generation(format_prompt(message, history, systemPrompt),
                                    **client_settings)
    # Read the stream token by token.
    partial_response = ""
    for stream_part in stream:
        # Fix: with details=True each part flags special tokens (e.g. </s>);
        # skip them so they never leak into the text shown to the user.
        if stream_part.token.special:
            continue
        partial_response += stream_part.token.text
        yield partial_response
# Extra controls shown under the chat box. Their order must match the extra
# parameters of inference(): systemPrompt, temperature, maxTokens, topP,
# repPenalty.
myAdditionalInputs = [
    gr.Textbox(
        label="System Prompt",
        max_lines=500,
        lines=10,
        interactive=True,
        value="You are a friendly girl who doesn't answer unnecessarily long."
    ),
    gr.Slider(
        label="Temperature",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=256,
        minimum=0,
        # Fix: 1048 was a typo for 1024 — unlike 1048, it is the intended
        # power-of-two cap and a multiple of the 64-token step.
        maximum=1024,
        step=64,
        interactive=True,
        info="The maximum numbers of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.1,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    ),
]
# Chat display: custom avatars, compact bubbles, no label/copy/like chrome.
myChatbot = gr.Chatbot(
    avatar_images=["./ava_m.png", "./ava_f.png"],
    bubble_full_width=False,
    show_label=False,
    show_copy_button=False,
    likeable=False,
)
# Two-line message box; scale=7 gives it most of the row next to the button.
myTextInput = gr.Textbox(
    lines=2,
    max_lines=2,
    placeholder="Send a message",
    container=False,
    scale=7,
)
# Soft theme in fuchsia with tightened spacing and medium corner radii.
myTheme = gr.themes.Soft(
    primary_hue=gr.themes.colors.fuchsia,
    secondary_hue=gr.themes.colors.fuchsia,
    spacing_size="sm",
    radius_size="md",
)
# Action buttons for the chat interface; SEND is the primary action, the
# rest are small secondary controls.
mySubmitButton = gr.Button(value="SEND", variant="primary")
myRetryButton = gr.Button(value="RETRY", variant="secondary", size="sm")
myUndoButton = gr.Button(value="UNDO", variant="secondary", size="sm")
myClearButton = gr.Button(value="CLEAR", variant="secondary", size="sm")
# Wire everything together and launch the app; queue() is required for the
# streaming (generator) inference function, and the API view is hidden.
gr.ChatInterface(
    inference,
    chatbot=myChatbot,
    textbox=myTextInput,
    title="My chat bot",
    theme=myTheme,
    additional_inputs=myAdditionalInputs,
    submit_btn=mySubmitButton,
    stop_btn="STOP",
    retry_btn=myRetryButton,
    undo_btn=myUndoButton,
    clear_btn=myClearButton,
).queue().launch(show_api=False)