einfachalf commited on
Commit
a45babd
·
1 Parent(s): b9d2808

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +104 -0
app.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import InferenceClient
2
+ import gradio as gr
3
+ import os
4
+
5
+ API_URL = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.1"
6
+ HF_TOKEN = os.environ['HF_TOKEN']
7
+
8
+ client = InferenceClient(
9
+ API_URL,
10
+ headers = {"Authorization" : f"Bearer {HF_TOKEN}"},
11
+ )
12
+
13
+ client_unverified = InferenceClient(
14
+ "mistralai/Mistral-7B-Instruct-v0.1"
15
+ )
16
+
17
+ def format_prompt(message, history):
18
+ prompt = "<s>"
19
+ for user_prompt, bot_response in history:
20
+ prompt += f"[INST] {user_prompt} [/INST]"
21
+ prompt += f" {bot_response}</s> "
22
+ prompt += f"[INST] {message} [/INST]"
23
+ return prompt
24
+
25
+ def generate(prompt, history, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,):
26
+ temperature = float(temperature)
27
+ if temperature < 1e-2:
28
+ temperature = 1e-2
29
+ top_p = float(top_p)
30
+ generate_kwargs = dict(
31
+ temperature=temperature,
32
+ max_new_tokens=max_new_tokens,
33
+ top_p=top_p,
34
+ repetition_penalty=repetition_penalty,
35
+ do_sample=True,
36
+ seed=42,
37
+ )
38
+
39
+ formatted_prompt = format_prompt(prompt, history)
40
+ stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
41
+ output = ""
42
+ for response in stream:
43
+ output += response.token.text
44
+ yield output
45
+ return output
46
+
47
+ additional_inputs=[
48
+ gr.Slider(
49
+ label="Temperature",
50
+ value=0.9,
51
+ minimum=0.0,
52
+ maximum=1.0,
53
+ step=0.05,
54
+ interactive=True,
55
+ info="Higher values produce more diverse outputs",
56
+ ),
57
+ gr.Slider(
58
+ label="Max new tokens",
59
+ value=256,
60
+ minimum=0,
61
+ maximum=1048,
62
+ step=64,
63
+ interactive=True,
64
+ info="The maximum numbers of new tokens",
65
+ ),
66
+ gr.Slider(
67
+ label="Top-p (nucleus sampling)",
68
+ value=0.90,
69
+ minimum=0.0,
70
+ maximum=1,
71
+ step=0.05,
72
+ interactive=True,
73
+ info="Higher values sample more low-probability tokens",
74
+ ),
75
+ gr.Slider(
76
+ label="Repetition penalty",
77
+ value=1.2,
78
+ minimum=1.0,
79
+ maximum=2.0,
80
+ step=0.05,
81
+ interactive=True,
82
+ info="Penalize repeated tokens",
83
+ )
84
+ ]
85
+
86
+ css = """
87
+ #mkd {
88
+ height: 500px;
89
+ overflow: auto;
90
+ border: 1px solid #ccc;
91
+ }
92
+ """
93
+
94
+ with gr.Blocks(css=css, theme="NoCrypt/miku@1.2.1") as demo:
95
+ gr.HTML("<h1><center>MistralTalk<h1><center>")
96
+ gr.HTML("<h3><center>💬<h3><center>")
97
+ gr.HTML("<h3><center>Learn more about the model <a href='https://huggingface.co/docs/transformers/main/model_doc/mistral'>here</a>. 📚<h3><center>")
98
+ gr.ChatInterface(
99
+ generate,
100
+ additional_inputs=additional_inputs,
101
+ examples=[["What is the secret to life?"], ["How the universe works?"],["What can you do?"]]
102
+ )
103
+
104
+ demo.queue(max_size=100).launch(debug=True)